Coverage priorities 1-5: test pure functions, extract interfaces, mock-based handler tests

- Priority 1: Test NewSendEmailTask + NewSendPushTask (5 tests)
- Priority 2: Test customHTTPErrorHandler — all 15+ branches (21 tests)
- Priority 3: Extract Enqueuer interface + payload builders in worker pkg (5 tests)
- Priority 4: Extract ClassifyFile/ComputeRelPath in migrate-encrypt (6 tests)
- Priority 5: Define Handler interfaces, refactor to accept them, mock-based tests (14 tests)
- Fix .gitignore: /worker instead of worker to stop ignoring internal/worker/

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey T
2026-04-01 20:30:09 -05:00
parent 00fd674b56
commit bec880886b
83 changed files with 19569 additions and 730 deletions

View File

@@ -12,7 +12,9 @@
"Bash(git add:*)",
"Bash(docker ps:*)",
"Bash(git commit:*)",
"Bash(git push:*)"
"Bash(git push:*)",
"Bash(docker info:*)",
"Bash(curl:*)"
]
},
"enableAllProjectMcpServers": true,

2
.gitignore vendored
View File

@@ -6,7 +6,7 @@
# Binaries
bin/
api
worker
/worker
/admin
!admin/
*.exe

View File

@@ -0,0 +1,61 @@
package main
import (
"testing"
"time"
)
func TestClassifyCompletion_CompletedAfterDue_ReturnsOverdue(t *testing.T) {
due := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
completed := time.Date(2025, 6, 5, 14, 0, 0, 0, time.UTC)
got := classifyCompletion(completed, due, 7)
if got != "overdue_tasks" {
t.Errorf("got %q, want overdue_tasks", got)
}
}
func TestClassifyCompletion_CompletedOnDueDate_ReturnsDueSoon(t *testing.T) {
due := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
completed := time.Date(2025, 6, 1, 10, 0, 0, 0, time.UTC)
got := classifyCompletion(completed, due, 7)
if got != "due_soon_tasks" {
t.Errorf("got %q, want due_soon_tasks", got)
}
}
func TestClassifyCompletion_CompletedWithinThreshold_ReturnsDueSoon(t *testing.T) {
due := time.Date(2025, 6, 10, 0, 0, 0, 0, time.UTC)
completed := time.Date(2025, 6, 5, 0, 0, 0, 0, time.UTC) // 5 days before due, threshold 7
got := classifyCompletion(completed, due, 7)
if got != "due_soon_tasks" {
t.Errorf("got %q, want due_soon_tasks", got)
}
}
func TestClassifyCompletion_CompletedAtExactThreshold_ReturnsDueSoon(t *testing.T) {
due := time.Date(2025, 6, 10, 0, 0, 0, 0, time.UTC)
completed := time.Date(2025, 6, 3, 0, 0, 0, 0, time.UTC) // exactly 7 days before due
got := classifyCompletion(completed, due, 7)
if got != "due_soon_tasks" {
t.Errorf("got %q, want due_soon_tasks", got)
}
}
func TestClassifyCompletion_CompletedBeyondThreshold_ReturnsUpcoming(t *testing.T) {
due := time.Date(2025, 6, 30, 0, 0, 0, 0, time.UTC)
completed := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC) // 29 days before due, threshold 7
got := classifyCompletion(completed, due, 7)
if got != "upcoming_tasks" {
t.Errorf("got %q, want upcoming_tasks", got)
}
}
func TestClassifyCompletion_TimeNormalization_SameDayDifferentTimes(t *testing.T) {
due := time.Date(2025, 6, 1, 23, 59, 59, 0, time.UTC)
completed := time.Date(2025, 6, 1, 0, 0, 1, 0, time.UTC) // same day, different times
got := classifyCompletion(completed, due, 7)
// Same day → daysBefore == 0 → within threshold → due_soon
if got != "due_soon_tasks" {
t.Errorf("got %q, want due_soon_tasks", got)
}
}

View File

@@ -0,0 +1,50 @@
package main
import (
"path/filepath"
"strings"
)
// isEncrypted checks if a file path ends with .enc
func isEncrypted(path string) bool {
return strings.HasSuffix(path, ".enc")
}
// encryptedPath appends .enc to the file path.
func encryptedPath(path string) string {
return path + ".enc"
}
// shouldProcessFile returns true if the file should be encrypted.
func shouldProcessFile(isDir bool, path string) bool {
return !isDir && !isEncrypted(path)
}
// FileAction represents the decision about what to do with a file during encryption migration.
type FileAction int
const (
ActionSkipDir FileAction = iota // Directory, skip
ActionSkipEncrypted // Already encrypted, skip
ActionDryRun // Would encrypt (dry run mode)
ActionEncrypt // Should encrypt
)
// ClassifyFile determines what action to take for a file during the walk.
func ClassifyFile(isDir bool, path string, dryRun bool) FileAction {
if isDir {
return ActionSkipDir
}
if isEncrypted(path) {
return ActionSkipEncrypted
}
if dryRun {
return ActionDryRun
}
return ActionEncrypt
}
// ComputeRelPath computes the relative path from base to path.
func ComputeRelPath(base, path string) (string, error) {
return filepath.Rel(base, path)
}

View File

@@ -0,0 +1,96 @@
package main
import "testing"
func TestIsEncrypted_EncFile_True(t *testing.T) {
if !isEncrypted("photo.jpg.enc") {
t.Error("expected true for .enc file")
}
}
func TestIsEncrypted_PdfFile_False(t *testing.T) {
if isEncrypted("doc.pdf") {
t.Error("expected false for .pdf file")
}
}
func TestIsEncrypted_DotEncOnly_True(t *testing.T) {
if !isEncrypted(".enc") {
t.Error("expected true for '.enc'")
}
}
func TestEncryptedPath_AppendsDotEnc(t *testing.T) {
got := encryptedPath("uploads/photo.jpg")
want := "uploads/photo.jpg.enc"
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestShouldProcessFile_RegularFile_True(t *testing.T) {
if !shouldProcessFile(false, "photo.jpg") {
t.Error("expected true for regular file")
}
}
func TestShouldProcessFile_Directory_False(t *testing.T) {
if shouldProcessFile(true, "uploads") {
t.Error("expected false for directory")
}
}
func TestShouldProcessFile_AlreadyEncrypted_False(t *testing.T) {
if shouldProcessFile(false, "photo.jpg.enc") {
t.Error("expected false for already encrypted file")
}
}
// --- ClassifyFile ---
func TestClassifyFile_Directory_SkipDir(t *testing.T) {
if got := ClassifyFile(true, "uploads", false); got != ActionSkipDir {
t.Errorf("got %d, want ActionSkipDir", got)
}
}
func TestClassifyFile_EncryptedFile_SkipEncrypted(t *testing.T) {
if got := ClassifyFile(false, "photo.jpg.enc", false); got != ActionSkipEncrypted {
t.Errorf("got %d, want ActionSkipEncrypted", got)
}
}
func TestClassifyFile_DryRun_DryRun(t *testing.T) {
if got := ClassifyFile(false, "photo.jpg", true); got != ActionDryRun {
t.Errorf("got %d, want ActionDryRun", got)
}
}
func TestClassifyFile_Normal_Encrypt(t *testing.T) {
if got := ClassifyFile(false, "photo.jpg", false); got != ActionEncrypt {
t.Errorf("got %d, want ActionEncrypt", got)
}
}
// --- ComputeRelPath ---
func TestComputeRelPath_Valid(t *testing.T) {
got, err := ComputeRelPath("/uploads", "/uploads/photo.jpg")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "photo.jpg" {
t.Errorf("got %q, want %q", got, "photo.jpg")
}
}
func TestComputeRelPath_NestedPath(t *testing.T) {
got, err := ComputeRelPath("/uploads", "/uploads/2024/01/photo.jpg")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "2024/01/photo.jpg"
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}

View File

@@ -13,7 +13,6 @@ import (
"flag"
"os"
"path/filepath"
"strings"
"time"
"github.com/rs/zerolog"
@@ -87,13 +86,11 @@ func main() {
return nil
}
// Skip directories
if info.IsDir() {
action := ClassifyFile(info.IsDir(), path, *dryRun)
switch action {
case ActionSkipDir:
return nil
}
// Skip files already encrypted
if strings.HasSuffix(path, ".enc") {
case ActionSkipEncrypted:
skipped++
return nil
}
@@ -101,14 +98,14 @@ func main() {
totalFiles++
// Compute the relative path from upload dir
relPath, err := filepath.Rel(absUploadDir, path)
relPath, err := ComputeRelPath(absUploadDir, path)
if err != nil {
log.Warn().Err(err).Str("path", path).Msg("Failed to compute relative path")
errCount++
return nil
}
if *dryRun {
if action == ActionDryRun {
log.Info().Str("file", relPath).Msg("[DRY RUN] Would encrypt")
return nil
}

24
cmd/worker/startup.go Normal file
View File

@@ -0,0 +1,24 @@
package main
import "github.com/treytartt/honeydue-api/internal/worker/jobs"
// queuePriorities returns the Asynq queue priority map.
func queuePriorities() map[string]int {
return map[string]int{
"critical": 6,
"default": 3,
"low": 1,
}
}
// allJobTypes returns all registered job type strings.
func allJobTypes() []string {
return []string{
jobs.TypeSmartReminder,
jobs.TypeDailyDigest,
jobs.TypeSendEmail,
jobs.TypeSendPush,
jobs.TypeOnboardingEmails,
jobs.TypeReminderLogCleanup,
}
}

View File

@@ -0,0 +1,45 @@
package main
import (
"testing"
)
func TestQueuePriorities_CriticalHighest(t *testing.T) {
p := queuePriorities()
if p["critical"] <= p["default"] || p["critical"] <= p["low"] {
t.Errorf("critical (%d) should be highest", p["critical"])
}
}
func TestQueuePriorities_ThreeQueues(t *testing.T) {
p := queuePriorities()
if len(p) != 3 {
t.Errorf("len = %d, want 3", len(p))
}
}
func TestAllJobTypes_Count(t *testing.T) {
types := allJobTypes()
if len(types) != 6 {
t.Errorf("len = %d, want 6", len(types))
}
}
func TestAllJobTypes_NoDuplicates(t *testing.T) {
types := allJobTypes()
seen := make(map[string]bool)
for _, typ := range types {
if seen[typ] {
t.Errorf("duplicate job type: %q", typ)
}
seen[typ] = true
}
}
func TestAllJobTypes_AllNonEmpty(t *testing.T) {
for _, typ := range allJobTypes() {
if typ == "" {
t.Error("found empty job type")
}
}
}

View File

@@ -2704,6 +2704,105 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
/auth/account/:
delete:
tags: [Authentication]
summary: Delete user account
description: Permanently deletes the authenticated user's account and all associated data
security:
- tokenAuth: []
requestBody:
content:
application/json:
schema:
type: object
properties:
password:
type: string
description: Required for email-auth users
confirmation:
type: string
description: Must be "DELETE" for social-auth users
responses:
'200':
description: Account deleted successfully
'400':
$ref: '#/components/responses/BadRequest'
'401':
$ref: '#/components/responses/Unauthorized'
/auth/refresh/:
post:
tags: [Authentication]
summary: Refresh auth token
description: Returns a new token if current token is in the renewal window (60-90 days old)
security:
- tokenAuth: []
responses:
'200':
description: Token refreshed
content:
application/json:
schema:
type: object
properties:
token:
type: string
'401':
$ref: '#/components/responses/Unauthorized'
/health/live:
get:
tags: [Health]
summary: Liveness probe
description: Simple liveness check, always returns 200
responses:
'200':
description: Alive
/tasks/suggestions/:
get:
tags: [Tasks]
summary: Get personalized task template suggestions
description: Returns task templates ranked by relevance to the residence's home profile
security:
- tokenAuth: []
parameters:
- name: residence_id
in: query
required: true
schema:
type: integer
responses:
'200':
description: Suggestions with relevance scores
content:
application/json:
schema:
type: object
properties:
suggestions:
type: array
items:
type: object
properties:
template:
$ref: '#/components/schemas/TaskTemplate'
relevance_score:
type: number
match_reasons:
type: array
items:
type: string
total_count:
type: integer
profile_completeness:
type: number
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
# =============================================================================
# Components
# =============================================================================

302
docs/server_2026_2_24.md Normal file
View File

@@ -0,0 +1,302 @@
# Casera Infrastructure Plan — February 2026
## Architecture Overview
```
┌─────────────┐
│ Cloudflare │
│ (CDN/DNS) │
└──────┬──────┘
│ HTTPS
┌──────┴──────┐
│ Hetzner LB │
│ ($5.99) │
└──────┬──────┘
┌────────────────┼────────────────┐
│ │ │
┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐
│ CX33 #1 │ │ CX33 #2 │ │ CX33 #3 │
│ (manager) │ │ (manager) │ │ (manager) │
│ │ │ │ │ │
│ api (x2) │ │ api (x2) │ │ api (x1) │
│ admin │ │ worker │ │ worker │
│ redis │ │ dozzle │ │ │
└──────┬──────┘ └──────┬──────┘ └──────┬──────┘
│ │ │
│ Docker Swarm Overlay (IPsec) │
└────────────────┼────────────────┘
┌────────────┼────────────────┐
│ │
┌──────┴──────┐ ┌───────┴──────┐
│ Neon │ │ Backblaze │
│ (Postgres) │ │ B2 │
│ Launch │ │ (media) │
└─────────────┘ └──────────────┘
```
## Swarm Nodes — Hetzner CX33
All 3 nodes are manager+worker (Raft consensus requires 3 managers for fault tolerance — 1 node can go down and the cluster stays operational).
| Spec | Value |
|------|-------|
| Plan | CX33 (Shared Regular Performance) |
| vCPU | 4 |
| RAM | 8 GB |
| Disk | 80 GB SSD |
| Traffic | 20 TB/mo included |
| Price | $6.59/mo per node |
| Region | Pick closest to users (US: Ashburn or Hillsboro, EU: Nuremberg/Falkenstein/Helsinki) |
**Why CX33 over CX23:** 8 GB RAM gives headroom for Redis, multiple API replicas, and the admin panel without pressure. The $2.50/mo difference per node isn't worth optimizing away.
### Container Distribution
| Container | Replicas | Notes |
|-----------|----------|-------|
| api | 3-6 | Spread across all nodes by Swarm |
| worker | 2-3 | Asynq workers pull jobs from Redis concurrently |
| admin | 1 | Next.js admin panel |
| redis | 1 | Pinned to one node with its volume |
| dozzle | 1 | Pinned to a manager node (needs Docker socket) |
### Scaling Path
- Need more capacity? Add another CX33 with `docker swarm join`. Swarm rebalances automatically.
- Need more API throughput? Bump replicas in the compose file. No infra change.
- Only infrastructure addition needed at scale: the Hetzner Load Balancer ($5.99/mo).
## Load Balancer — Hetzner LB
| Spec | Value |
|------|-------|
| Price | $5.99/mo |
| Purpose | Distribute traffic across Swarm nodes, TLS termination |
| When to add | When you need redundant ingress (not required day 1 if using Cloudflare to proxy to a single node) |
## Database — Neon Postgres (Launch Plan)
| Spec | Value |
|------|-------|
| Plan | Launch (usage-based, no monthly minimum) |
| Compute | $0.106/CU-hr, up to 16 CU (64 GB RAM) |
| Storage | $0.35/GB-month |
| Connections | Up to 10,000 via built-in PgBouncer |
| Typical cost | ~$5-15/mo for light load, ~$20-40/mo at 100k users |
| Free tier | Available for dev/staging (100 CU-hrs/mo, 0.5 GB) |
### Connection Pooling
Neon includes built-in PgBouncer on all plans. Enable by adding `-pooler` to the hostname:
```
# Direct connection
ep-cool-darkness-123456.us-east-2.aws.neon.tech
# Pooled connection (use this in production)
ep-cool-darkness-123456-pooler.us-east-2.aws.neon.tech
```
Runs in transaction mode — compatible with GORM out of the box.
### Configuration
```env
DB_HOST=ep-xxxxx-pooler.us-east-2.aws.neon.tech
DB_PORT=5432
DB_SSLMODE=require
POSTGRES_USER=<from neon dashboard>
POSTGRES_PASSWORD=<from neon dashboard>
POSTGRES_DB=casera
```
## Object Storage — Backblaze B2
| Spec | Value |
|------|-------|
| Storage | $6/TB/mo ($0.006/GB) |
| Egress | $0.01/GB (first 3x stored amount is free) |
| Free tier | 10 GB storage always free |
| API calls | Class A free, Class B/C free first 2,500/day |
| Spending cap | Built-in data caps with alerts at 75% and 100% |
### Bucket Setup
| Bucket | Visibility | Key Permissions | Contents |
|--------|------------|-----------------|----------|
| `casera-uploads` | Private | Read/Write (API containers) | User-uploaded photos, documents |
| `casera-certs` | Private | Read-only (API + worker) | APNs push certificates |
Serve files through the API using signed URLs — never expose buckets publicly.
### Why B2 Over Others
- **Spending cap**: only S3-compatible provider with built-in hard caps and alerts. No surprise bills.
- **Cheapest storage**: $6/TB vs Cloudflare R2 at $15/TB vs Tigris at $20/TB.
- **Free egress partner CDNs**: Cloudflare, Fastly, bunny.net — zero egress when behind Cloudflare.
## CDN — Cloudflare (Free Tier)
| Spec | Value |
|------|-------|
| Price | $0 |
| Purpose | DNS, CDN caching, DDoS protection, TLS termination |
| Setup | Point DNS to Cloudflare, proxy traffic to Hetzner LB (or directly to a Swarm node) |
Add this on day 1. No reason not to.
## Logging — Dozzle
| Spec | Value |
|------|-------|
| Price | $0 (open source) |
| Port | 9999 (internal only — do not expose publicly) |
| Features | Real-time log viewer, webhook support for alerts |
Runs as a container in the Swarm. Needs Docker socket access, so it's pinned to a manager node.
For 100k+ users, consider adding Prometheus + Grafana (self-hosted, free) or Betterstack (~$10/mo) for metrics and alerting beyond log viewing.
## Security
### Swarm Node Firewall (Hetzner Cloud Firewall — free)
| Port | Protocol | Source | Purpose |
|------|----------|--------|---------|
| Custom (e.g. 2222) | TCP | Your IP only | SSH |
| 80, 443 | TCP | Anywhere | Public traffic |
| 2377 | TCP | Swarm nodes only | Cluster management |
| 7946 | TCP/UDP | Swarm nodes only | Node discovery |
| 4789 | UDP | Swarm nodes only | Overlay network (VXLAN) |
| Everything else | — | — | Blocked |
Set up once in Hetzner dashboard, apply to all 3 nodes.
### SSH Hardening
```
# /etc/ssh/sshd_config
Port 2222 # Non-default port
PermitRootLogin no # No root SSH
PasswordAuthentication no # Key-only auth
PubkeyAuthentication yes
AllowUsers deploy # Only your deploy user
```
### Swarm ↔ Neon (Postgres)
| Layer | Method |
|-------|--------|
| Encryption | TLS enforced by Neon (`DB_SSLMODE=require`) |
| Authentication | Strong password stored as Docker secret |
| Access control | IP allowlist in Neon dashboard — restrict to 3 Swarm node IPs |
### Swarm ↔ B2 (Object Storage)
| Layer | Method |
|-------|--------|
| Encryption | HTTPS always (enforced by B2 API) |
| Authentication | Scoped application keys (not master key) |
| Access control | Per-bucket key permissions (read-only where possible) |
### Swarm Internal
| Layer | Method |
|-------|--------|
| Overlay encryption | `driver_opts: encrypted: "true"` on overlay network (IPsec between nodes) |
| Secrets | Use `docker secret create` for DB password, SECRET_KEY, B2 keys, APNs keys. Mounted at `/run/secrets/`, encrypted in Swarm raft log. |
| Container isolation | Non-root users in all containers (already configured in Dockerfile) |
### Docker Secrets Migration
Current setup uses environment variables for secrets. Migrate to Docker secrets for production:
```bash
# Create secrets
echo "your-db-password" | docker secret create postgres_password -
echo "your-secret-key" | docker secret create secret_key -
echo "your-b2-app-key" | docker secret create b2_app_key -
# Reference in compose file
services:
api:
secrets:
- postgres_password
- secret_key
secrets:
postgres_password:
external: true
secret_key:
external: true
```
Application code reads from `/run/secrets/<name>` instead of env vars.
## Redis (In-Cluster)
Redis stays inside the Swarm — no need to externalize.
| Purpose | Details |
|---------|---------|
| Asynq job queue | Background jobs: push notifications, digests, reminders, onboarding emails |
| Static data cache | Cached lookup tables with ETag support |
| Resource usage | ~20-50 MB RAM, negligible CPU |
At 100k users, Redis handles job queuing for nightly digests (100k enqueue + dequeue operations) without issue. A single Redis instance handles millions of operations per second.
Asynq coordinates multiple worker replicas automatically — each job is dequeued atomically by exactly one worker, no double-processing.
## Performance Estimates
| Metric | Value |
|--------|-------|
| Single CX33 API throughput | ~1,000-2,000 req/s (blended, with Neon latency) |
| 3-node cluster throughput | ~3,000-6,000 req/s |
| Avg requests per user per day | ~50 |
| Estimated user capacity (3 nodes) | ~200k-500k registered users |
| Bottleneck at scale | Neon compute tier, not Go or Swarm |
These are napkin estimates. Load test before launch.
## Monthly Cost Summary
### Starting Out
| Component | Provider | Cost |
|-----------|----------|------|
| 3x Swarm nodes | Hetzner CX33 | $19.77/mo |
| Postgres | Neon Launch | ~$5-15/mo |
| Object storage | Backblaze B2 | <$1/mo |
| CDN | Cloudflare Free | $0 |
| Logging | Dozzle (self-hosted) | $0 |
| **Total** | | **~$25-35/mo** |
### At Scale (100k users)
| Component | Provider | Cost |
|-----------|----------|------|
| 3x Swarm nodes | Hetzner CX33 | $19.77/mo |
| Load balancer | Hetzner LB | $5.99/mo |
| Postgres | Neon Launch | ~$20-40/mo |
| Object storage | Backblaze B2 | ~$1-3/mo |
| CDN | Cloudflare Free | $0 |
| Monitoring | Betterstack or self-hosted | ~$0-10/mo |
| **Total** | | **~$47-79/mo** |
## TODO
- [ ] Set up 3x Hetzner CX33 instances
- [ ] Initialize Docker Swarm (`docker swarm init` on first node, `docker swarm join` on others)
- [ ] Configure Hetzner Cloud Firewall
- [ ] Harden SSH on all nodes
- [ ] Create Neon project (Launch plan), configure IP allowlist
- [ ] Create Backblaze B2 buckets with scoped application keys
- [ ] Set up Cloudflare DNS proxying
- [ ] Update prod compose file: remove `db` service, add overlay encryption, add Docker secrets
- [ ] Add B2 SDK integration for file uploads (code change)
- [ ] Update config to read from `/run/secrets/` for Docker secrets
- [ ] Set B2 spending cap and alerts
- [ ] Load test the deployed stack
- [ ] Add Hetzner LB when needed

View File

@@ -0,0 +1,176 @@
package dto
import (
"testing"
)
// --- GetPage ---
func TestGetPage_Zero_Returns1(t *testing.T) {
p := &PaginationParams{Page: 0}
if got := p.GetPage(); got != 1 {
t.Errorf("GetPage(0) = %d, want 1", got)
}
}
func TestGetPage_Negative_Returns1(t *testing.T) {
p := &PaginationParams{Page: -5}
if got := p.GetPage(); got != 1 {
t.Errorf("GetPage(-5) = %d, want 1", got)
}
}
func TestGetPage_Valid_ReturnsValue(t *testing.T) {
p := &PaginationParams{Page: 3}
if got := p.GetPage(); got != 3 {
t.Errorf("GetPage(3) = %d, want 3", got)
}
}
// --- GetPerPage ---
func TestGetPerPage_Zero_Returns20(t *testing.T) {
p := &PaginationParams{PerPage: 0}
if got := p.GetPerPage(); got != 20 {
t.Errorf("GetPerPage(0) = %d, want 20", got)
}
}
func TestGetPerPage_Negative_Returns20(t *testing.T) {
p := &PaginationParams{PerPage: -1}
if got := p.GetPerPage(); got != 20 {
t.Errorf("GetPerPage(-1) = %d, want 20", got)
}
}
func TestGetPerPage_TooLarge_Returns10000(t *testing.T) {
p := &PaginationParams{PerPage: 20000}
if got := p.GetPerPage(); got != 10000 {
t.Errorf("GetPerPage(20000) = %d, want 10000", got)
}
}
func TestGetPerPage_Valid_ReturnsValue(t *testing.T) {
p := &PaginationParams{PerPage: 50}
if got := p.GetPerPage(); got != 50 {
t.Errorf("GetPerPage(50) = %d, want 50", got)
}
}
// --- GetOffset ---
func TestGetOffset_Page1_Returns0(t *testing.T) {
p := &PaginationParams{Page: 1, PerPage: 20}
if got := p.GetOffset(); got != 0 {
t.Errorf("GetOffset(page=1, perPage=20) = %d, want 0", got)
}
}
func TestGetOffset_Page3_PerPage10_Returns20(t *testing.T) {
p := &PaginationParams{Page: 3, PerPage: 10}
if got := p.GetOffset(); got != 20 {
t.Errorf("GetOffset(page=3, perPage=10) = %d, want 20", got)
}
}
func TestGetOffset_Defaults_Returns0(t *testing.T) {
p := &PaginationParams{}
if got := p.GetOffset(); got != 0 {
t.Errorf("GetOffset(defaults) = %d, want 0", got)
}
}
// --- GetSortDir ---
func TestGetSortDir_Asc(t *testing.T) {
p := &PaginationParams{SortDir: "asc"}
if got := p.GetSortDir(); got != "ASC" {
t.Errorf("GetSortDir('asc') = %q, want 'ASC'", got)
}
}
func TestGetSortDir_Desc(t *testing.T) {
p := &PaginationParams{SortDir: "desc"}
if got := p.GetSortDir(); got != "DESC" {
t.Errorf("GetSortDir('desc') = %q, want 'DESC'", got)
}
}
func TestGetSortDir_Empty_ReturnsDesc(t *testing.T) {
p := &PaginationParams{SortDir: ""}
if got := p.GetSortDir(); got != "DESC" {
t.Errorf("GetSortDir('') = %q, want 'DESC'", got)
}
}
func TestGetSortDir_Invalid_ReturnsDesc(t *testing.T) {
p := &PaginationParams{SortDir: "RANDOM"}
if got := p.GetSortDir(); got != "DESC" {
t.Errorf("GetSortDir('RANDOM') = %q, want 'DESC'", got)
}
}
// --- GetSafeSortBy ---
func TestGetSafeSortBy_Allowed(t *testing.T) {
p := &PaginationParams{SortBy: "name"}
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
if got != "name" {
t.Errorf("GetSafeSortBy('name') = %q, want 'name'", got)
}
}
func TestGetSafeSortBy_NotAllowed_ReturnsDefault(t *testing.T) {
p := &PaginationParams{SortBy: "password"}
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
if got != "id" {
t.Errorf("GetSafeSortBy('password') = %q, want 'id'", got)
}
}
func TestGetSafeSortBy_Empty_ReturnsDefault(t *testing.T) {
p := &PaginationParams{SortBy: ""}
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
if got != "id" {
t.Errorf("GetSafeSortBy('') = %q, want 'id'", got)
}
}
// --- NewPaginatedResponse ---
func TestNewPaginatedResponse_ExactPages(t *testing.T) {
resp := NewPaginatedResponse([]string{"a", "b"}, 40, 1, 20)
if resp.TotalPages != 2 {
t.Errorf("TotalPages = %d, want 2", resp.TotalPages)
}
if resp.Total != 40 {
t.Errorf("Total = %d, want 40", resp.Total)
}
if resp.Page != 1 {
t.Errorf("Page = %d, want 1", resp.Page)
}
if resp.PerPage != 20 {
t.Errorf("PerPage = %d, want 20", resp.PerPage)
}
}
func TestNewPaginatedResponse_PartialLastPage(t *testing.T) {
resp := NewPaginatedResponse(nil, 21, 1, 20)
if resp.TotalPages != 2 {
t.Errorf("TotalPages = %d, want 2", resp.TotalPages)
}
}
func TestNewPaginatedResponse_SinglePage(t *testing.T) {
resp := NewPaginatedResponse(nil, 5, 1, 20)
if resp.TotalPages != 1 {
t.Errorf("TotalPages = %d, want 1", resp.TotalPages)
}
}
func TestNewPaginatedResponse_ZeroTotal(t *testing.T) {
resp := NewPaginatedResponse(nil, 0, 1, 20)
if resp.TotalPages != 0 {
t.Errorf("TotalPages = %d, want 0", resp.TotalPages)
}
}

View File

@@ -0,0 +1,109 @@
package apperrors
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNotFound(t *testing.T) {
err := NotFound("error.task_not_found")
assert.Equal(t, http.StatusNotFound, err.Code)
assert.Equal(t, "error.task_not_found", err.MessageKey)
assert.Empty(t, err.Message)
assert.Nil(t, err.Err)
}
func TestForbidden(t *testing.T) {
err := Forbidden("error.residence_access_denied")
assert.Equal(t, http.StatusForbidden, err.Code)
assert.Equal(t, "error.residence_access_denied", err.MessageKey)
}
func TestBadRequest(t *testing.T) {
err := BadRequest("error.invalid_request_body")
assert.Equal(t, http.StatusBadRequest, err.Code)
assert.Equal(t, "error.invalid_request_body", err.MessageKey)
}
func TestUnauthorized(t *testing.T) {
err := Unauthorized("error.not_authenticated")
assert.Equal(t, http.StatusUnauthorized, err.Code)
assert.Equal(t, "error.not_authenticated", err.MessageKey)
}
func TestConflict(t *testing.T) {
err := Conflict("error.email_taken")
assert.Equal(t, http.StatusConflict, err.Code)
assert.Equal(t, "error.email_taken", err.MessageKey)
}
func TestTooManyRequests(t *testing.T) {
err := TooManyRequests("error.rate_limit_exceeded")
assert.Equal(t, http.StatusTooManyRequests, err.Code)
assert.Equal(t, "error.rate_limit_exceeded", err.MessageKey)
}
func TestInternal(t *testing.T) {
underlying := errors.New("database connection failed")
err := Internal(underlying)
assert.Equal(t, http.StatusInternalServerError, err.Code)
assert.Equal(t, "error.internal", err.MessageKey)
assert.Equal(t, underlying, err.Err)
}
func TestAppError_Error_WithWrappedError(t *testing.T) {
underlying := errors.New("connection refused")
err := Internal(underlying).WithMessage("database error")
assert.Equal(t, "database error: connection refused", err.Error())
}
func TestAppError_Error_WithMessageOnly(t *testing.T) {
err := NotFound("error.task_not_found").WithMessage("Task not found")
assert.Equal(t, "Task not found", err.Error())
}
func TestAppError_Error_MessageKeyFallback(t *testing.T) {
err := NotFound("error.task_not_found")
// No Message set, no Err set — should fall back to MessageKey
assert.Equal(t, "error.task_not_found", err.Error())
}
func TestAppError_Unwrap(t *testing.T) {
underlying := errors.New("wrapped error")
err := Internal(underlying)
assert.Equal(t, underlying, errors.Unwrap(err))
}
func TestAppError_Unwrap_Nil(t *testing.T) {
err := NotFound("error.task_not_found")
assert.Nil(t, errors.Unwrap(err))
}
func TestAppError_WithMessage(t *testing.T) {
err := NotFound("error.task_not_found").WithMessage("custom message")
assert.Equal(t, "custom message", err.Message)
assert.Equal(t, "error.task_not_found", err.MessageKey)
}
func TestAppError_Wrap(t *testing.T) {
underlying := errors.New("some error")
err := BadRequest("error.invalid_request_body").Wrap(underlying)
assert.Equal(t, underlying, err.Err)
assert.Equal(t, http.StatusBadRequest, err.Code)
}
func TestAppError_ImplementsError(t *testing.T) {
var err error = NotFound("error.task_not_found")
assert.NotNil(t, err)
assert.Equal(t, "error.task_not_found", err.Error())
}
func TestAppError_ErrorsAs(t *testing.T) {
var appErr *AppError
err := NotFound("error.task_not_found")
assert.True(t, errors.As(err, &appErr))
assert.Equal(t, http.StatusNotFound, appErr.Code)
}

View File

@@ -0,0 +1,324 @@
package config
import (
"sync"
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// resetConfigState resets the package-level singleton so each test starts fresh.
func resetConfigState() {
cfg = nil
cfgOnce = sync.Once{}
viper.Reset()
}
func TestLoad_DefaultValues(t *testing.T) {
resetConfigState()
// Provide required SECRET_KEY so validation passes
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
c, err := Load()
require.NoError(t, err)
// Server defaults
assert.Equal(t, 8000, c.Server.Port)
assert.False(t, c.Server.Debug)
assert.False(t, c.Server.DebugFixedCodes)
assert.Equal(t, "UTC", c.Server.Timezone)
assert.Equal(t, "/app/static", c.Server.StaticDir)
assert.Equal(t, "https://api.myhoneydue.com", c.Server.BaseURL)
// Database defaults
assert.Equal(t, "localhost", c.Database.Host)
assert.Equal(t, 5432, c.Database.Port)
assert.Equal(t, "postgres", c.Database.User)
assert.Equal(t, "honeydue", c.Database.Database)
assert.Equal(t, "disable", c.Database.SSLMode)
assert.Equal(t, 25, c.Database.MaxOpenConns)
assert.Equal(t, 10, c.Database.MaxIdleConns)
// Redis defaults
assert.Equal(t, "redis://localhost:6379/0", c.Redis.URL)
assert.Equal(t, 0, c.Redis.DB)
// Worker defaults
assert.Equal(t, 14, c.Worker.TaskReminderHour)
assert.Equal(t, 15, c.Worker.OverdueReminderHour)
assert.Equal(t, 3, c.Worker.DailyNotifHour)
// Token expiry defaults
assert.Equal(t, 90, c.Security.TokenExpiryDays)
assert.Equal(t, 60, c.Security.TokenRefreshDays)
// Feature flags default to true
assert.True(t, c.Features.PushEnabled)
assert.True(t, c.Features.EmailEnabled)
assert.True(t, c.Features.WebhooksEnabled)
assert.True(t, c.Features.OnboardingEmailsEnabled)
assert.True(t, c.Features.PDFReportsEnabled)
assert.True(t, c.Features.WorkerEnabled)
}
func TestLoad_EnvOverrides(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
t.Setenv("PORT", "9090")
t.Setenv("DEBUG", "true")
t.Setenv("DB_HOST", "db.example.com")
t.Setenv("DB_PORT", "5433")
t.Setenv("TOKEN_EXPIRY_DAYS", "180")
t.Setenv("TOKEN_REFRESH_DAYS", "120")
t.Setenv("FEATURE_PUSH_ENABLED", "false")
c, err := Load()
require.NoError(t, err)
assert.Equal(t, 9090, c.Server.Port)
assert.True(t, c.Server.Debug)
assert.Equal(t, "db.example.com", c.Database.Host)
assert.Equal(t, 5433, c.Database.Port)
assert.Equal(t, 180, c.Security.TokenExpiryDays)
assert.Equal(t, 120, c.Security.TokenRefreshDays)
assert.False(t, c.Features.PushEnabled)
}
func TestLoad_Validation_MissingSecretKey_Production(t *testing.T) {
// Test validate() directly to avoid the sync.Once mutex issue
// that occurs when Load() resets cfgOnce inside cfgOnce.Do()
cfg := &Config{
Server: ServerConfig{Debug: false},
Security: SecurityConfig{SecretKey: ""},
}
err := validate(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "SECRET_KEY")
}
func TestLoad_Validation_MissingSecretKey_DebugMode(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "")
t.Setenv("DEBUG", "true")
c, err := Load()
require.NoError(t, err)
// In debug mode, a default key is assigned
assert.Equal(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
}
func TestLoad_Validation_WeakSecretKey_Production(t *testing.T) {
// Test validate() directly to avoid the sync.Once mutex issue
cfg := &Config{
Server: ServerConfig{Debug: false},
Security: SecurityConfig{SecretKey: "password"},
}
err := validate(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "well-known weak value")
}
func TestLoad_Validation_WeakSecretKey_DebugMode(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "secret")
t.Setenv("DEBUG", "true")
// In debug mode, weak keys produce a warning but no error
c, err := Load()
require.NoError(t, err)
assert.Equal(t, "secret", c.Security.SecretKey)
}
func TestLoad_Validation_EncryptionKey_Valid(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
// Valid 64-char hex key (32 bytes)
t.Setenv("STORAGE_ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
c, err := Load()
require.NoError(t, err)
assert.Equal(t, "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", c.Storage.EncryptionKey)
}
func TestLoad_Validation_EncryptionKey_WrongLength(t *testing.T) {
// Test validate() directly to avoid the sync.Once mutex issue
cfg := &Config{
Server: ServerConfig{Debug: false},
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
Storage: StorageConfig{EncryptionKey: "tooshort"},
}
err := validate(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "STORAGE_ENCRYPTION_KEY must be exactly 64 hex characters")
}
func TestLoad_Validation_EncryptionKey_InvalidHex(t *testing.T) {
// Test validate() directly to avoid the sync.Once mutex issue
cfg := &Config{
Server: ServerConfig{Debug: false},
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
Storage: StorageConfig{EncryptionKey: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"},
}
err := validate(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid hex")
}
func TestLoad_DatabaseURL_Override(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
t.Setenv("DATABASE_URL", "postgres://myuser:mypass@dbhost:5433/mydb?sslmode=require")
c, err := Load()
require.NoError(t, err)
assert.Equal(t, "dbhost", c.Database.Host)
assert.Equal(t, 5433, c.Database.Port)
assert.Equal(t, "myuser", c.Database.User)
assert.Equal(t, "mypass", c.Database.Password)
assert.Equal(t, "mydb", c.Database.Database)
assert.Equal(t, "require", c.Database.SSLMode)
}
func TestDSN(t *testing.T) {
d := DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: "Password123",
Database: "testdb",
SSLMode: "disable",
}
dsn := d.DSN()
assert.Contains(t, dsn, "host=localhost")
assert.Contains(t, dsn, "port=5432")
assert.Contains(t, dsn, "user=testuser")
assert.Contains(t, dsn, "password=Password123")
assert.Contains(t, dsn, "dbname=testdb")
assert.Contains(t, dsn, "sslmode=disable")
}
func TestMaskURLCredentials(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "URL with password",
input: "postgres://user:secret@host:5432/db",
expected: "postgres://user:xxxxx@host:5432/db",
},
{
name: "URL without password",
input: "postgres://user@host:5432/db",
expected: "postgres://user@host:5432/db",
},
{
name: "URL without user info",
input: "postgres://host:5432/db",
expected: "postgres://host:5432/db",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := MaskURLCredentials(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
func TestParseCorsOrigins(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{"empty string", "", nil},
{"single origin", "https://example.com", []string{"https://example.com"}},
{"multiple origins", "https://a.com, https://b.com", []string{"https://a.com", "https://b.com"}},
{"whitespace trimmed", " https://a.com , https://b.com ", []string{"https://a.com", "https://b.com"}},
{"empty parts skipped", "https://a.com,,https://b.com", []string{"https://a.com", "https://b.com"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := parseCorsOrigins(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
func TestParseDatabaseURL(t *testing.T) {
tests := []struct {
name string
url string
wantHost string
wantPort int
wantUser string
wantPass string
wantDB string
wantSSL string
expectError bool
}{
{
name: "full URL",
url: "postgres://user:Password123@host:5433/mydb?sslmode=require",
wantHost: "host",
wantPort: 5433,
wantUser: "user",
wantPass: "Password123",
wantDB: "mydb",
wantSSL: "require",
},
{
name: "default port",
url: "postgres://user:pass@host/mydb",
wantHost: "host",
wantPort: 5432,
wantUser: "user",
wantPass: "pass",
wantDB: "mydb",
wantSSL: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result, err := parseDatabaseURL(tc.url)
if tc.expectError {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tc.wantHost, result.Host)
assert.Equal(t, tc.wantPort, result.Port)
assert.Equal(t, tc.wantUser, result.User)
assert.Equal(t, tc.wantPass, result.Password)
assert.Equal(t, tc.wantDB, result.Database)
assert.Equal(t, tc.wantSSL, result.SSLMode)
})
}
}
func TestIsWeakSecretKey(t *testing.T) {
assert.True(t, isWeakSecretKey("secret"))
assert.True(t, isWeakSecretKey("Secret")) // case-insensitive
assert.True(t, isWeakSecretKey(" changeme ")) // whitespace trimmed
assert.True(t, isWeakSecretKey("password"))
assert.True(t, isWeakSecretKey("change-me"))
assert.False(t, isWeakSecretKey("a-strong-unique-production-key"))
}
func TestGet_ReturnsNilBeforeLoad(t *testing.T) {
resetConfigState()
assert.Nil(t, Get())
}

View File

@@ -0,0 +1,103 @@
package database
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// --- Unit tests for Paginate parameter clamping ---
func TestPaginate_PageZeroDefaultsToOne(t *testing.T) {
scope := Paginate(0, 10)
db := openTestDB(t)
createTestRows(t, db, 5)
var rows []testRow
err := db.Scopes(scope).Find(&rows).Error
require.NoError(t, err)
// page=0 normalised to page=1, pageSize=10 → should get all 5 rows
assert.Len(t, rows, 5)
}
func TestPaginate_PageSizeZeroDefaultsTo100(t *testing.T) {
scope := Paginate(1, 0)
db := openTestDB(t)
createTestRows(t, db, 5)
var rows []testRow
err := db.Scopes(scope).Find(&rows).Error
require.NoError(t, err)
// pageSize=0 normalised to 100, only 5 rows exist → 5 returned
assert.Len(t, rows, 5)
}
func TestPaginate_PageSizeOverMaxCappedAt1000(t *testing.T) {
scope := Paginate(1, 2000)
db := openTestDB(t)
createTestRows(t, db, 5)
var rows []testRow
err := db.Scopes(scope).Find(&rows).Error
require.NoError(t, err)
// pageSize=2000 capped to 1000, only 5 rows → 5 returned
assert.Len(t, rows, 5)
}
func TestPaginate_NormalValues(t *testing.T) {
scope := Paginate(1, 3)
db := openTestDB(t)
createTestRows(t, db, 10)
var rows []testRow
err := db.Scopes(scope).Order("id ASC").Find(&rows).Error
require.NoError(t, err)
assert.Len(t, rows, 3)
assert.Equal(t, "row_1", rows[0].Name)
assert.Equal(t, "row_3", rows[2].Name)
}
func TestPaginate_SQLiteIntegration_Page2Size10(t *testing.T) {
db := openTestDB(t)
createTestRows(t, db, 25)
scope := Paginate(2, 10)
var rows []testRow
err := db.Scopes(scope).Order("id ASC").Find(&rows).Error
require.NoError(t, err)
// Page 2 with size 10 → rows 11..20
assert.Len(t, rows, 10)
assert.Equal(t, "row_11", rows[0].Name)
assert.Equal(t, "row_20", rows[9].Name)
}
// --- helpers ---
type testRow struct {
ID uint `gorm:"primaryKey"`
Name string
}
func openTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&testRow{}))
return db
}
func createTestRows(t *testing.T, db *gorm.DB, n int) {
t.Helper()
for i := 1; i <= n; i++ {
require.NoError(t, db.Create(&testRow{Name: fmt.Sprintf("row_%d", i)}).Error)
}
}

View File

@@ -0,0 +1,47 @@
package database
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestClassifyCompletion_CompletedAfterDue(t *testing.T) {
dueDate := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
completedAt := time.Date(2025, 6, 5, 14, 30, 0, 0, time.UTC) // 4 days after due
result := classifyCompletion(completedAt, dueDate, 30)
assert.Equal(t, "overdue_tasks", result)
}
func TestClassifyCompletion_CompletedOnDueDate(t *testing.T) {
dueDate := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
completedAt := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) // same day
result := classifyCompletion(completedAt, dueDate, 30)
// Completed on the due date: daysBefore == 0, which is <= threshold → due_soon_tasks
assert.Equal(t, "due_soon_tasks", result)
}
func TestClassifyCompletion_CompletedWithinThreshold(t *testing.T) {
dueDate := time.Date(2025, 7, 1, 0, 0, 0, 0, time.UTC)
completedAt := time.Date(2025, 6, 10, 8, 0, 0, 0, time.UTC) // 21 days before due
result := classifyCompletion(completedAt, dueDate, 30)
// 21 days before due, within 30-day threshold → due_soon_tasks
assert.Equal(t, "due_soon_tasks", result)
}
func TestClassifyCompletion_CompletedBeyondThreshold(t *testing.T) {
dueDate := time.Date(2025, 9, 1, 0, 0, 0, 0, time.UTC)
completedAt := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) // 92 days before due
result := classifyCompletion(completedAt, dueDate, 30)
// 92 days before due, beyond 30-day threshold → upcoming_tasks
assert.Equal(t, "upcoming_tasks", result)
}

View File

@@ -0,0 +1,31 @@
package database
import "sort"
// sortMigrationNames returns a sorted copy of the names slice.
func sortMigrationNames(names []string) []string {
sorted := make([]string, len(names))
copy(sorted, names)
sort.Strings(sorted)
return sorted
}
// buildAppliedSet converts a list of applied migrations to a lookup set.
func buildAppliedSet(applied []DataMigration) map[string]bool {
set := make(map[string]bool, len(applied))
for _, m := range applied {
set[m.Name] = true
}
return set
}
// filterPending returns names not present in the applied set.
func filterPending(names []string, applied map[string]bool) []string {
var pending []string
for _, name := range names {
if !applied[name] {
pending = append(pending, name)
}
}
return pending
}

View File

@@ -0,0 +1,82 @@
package database
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// --- sortMigrationNames ---
func TestSortMigrationNames_Alphabetical(t *testing.T) {
input := []string{"charlie", "alpha", "bravo"}
result := sortMigrationNames(input)
assert.Equal(t, []string{"alpha", "bravo", "charlie"}, result)
// Verify original slice is not mutated
assert.Equal(t, []string{"charlie", "alpha", "bravo"}, input)
}
func TestSortMigrationNames_Empty(t *testing.T) {
result := sortMigrationNames([]string{})
assert.Equal(t, []string{}, result)
assert.Len(t, result, 0)
}
// --- buildAppliedSet ---
func TestBuildAppliedSet_Multiple(t *testing.T) {
applied := []DataMigration{
{ID: 1, Name: "20250101_first", AppliedAt: time.Now()},
{ID: 2, Name: "20250201_second", AppliedAt: time.Now()},
{ID: 3, Name: "20250301_third", AppliedAt: time.Now()},
}
set := buildAppliedSet(applied)
assert.Len(t, set, 3)
assert.True(t, set["20250101_first"])
assert.True(t, set["20250201_second"])
assert.True(t, set["20250301_third"])
assert.False(t, set["nonexistent"])
}
func TestBuildAppliedSet_Empty(t *testing.T) {
set := buildAppliedSet([]DataMigration{})
assert.Len(t, set, 0)
}
// --- filterPending ---
func TestFilterPending_SomePending(t *testing.T) {
names := []string{"20250101_first", "20250201_second", "20250301_third"}
applied := map[string]bool{
"20250101_first": true,
}
pending := filterPending(names, applied)
assert.Equal(t, []string{"20250201_second", "20250301_third"}, pending)
}
func TestFilterPending_AllApplied(t *testing.T) {
names := []string{"20250101_first", "20250201_second"}
applied := map[string]bool{
"20250101_first": true,
"20250201_second": true,
}
pending := filterPending(names, applied)
assert.Nil(t, pending)
}
func TestFilterPending_NoneApplied(t *testing.T) {
names := []string{"20250101_first", "20250201_second", "20250301_third"}
applied := map[string]bool{}
pending := filterPending(names, applied)
assert.Equal(t, []string{"20250101_first", "20250201_second", "20250301_third"}, pending)
}

View File

@@ -0,0 +1,130 @@
package requests
import (
"encoding/json"
"testing"
"time"
)
func TestFlexibleDate_UnmarshalJSON_DateOnly(t *testing.T) {
var fd FlexibleDate
err := fd.UnmarshalJSON([]byte(`"2025-11-27"`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := time.Date(2025, 11, 27, 0, 0, 0, 0, time.UTC)
if !fd.Time.Equal(want) {
t.Errorf("got %v, want %v", fd.Time, want)
}
}
func TestFlexibleDate_UnmarshalJSON_RFC3339(t *testing.T) {
var fd FlexibleDate
err := fd.UnmarshalJSON([]byte(`"2025-11-27T15:30:00Z"`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := time.Date(2025, 11, 27, 15, 30, 0, 0, time.UTC)
if !fd.Time.Equal(want) {
t.Errorf("got %v, want %v", fd.Time, want)
}
}
func TestFlexibleDate_UnmarshalJSON_Null(t *testing.T) {
var fd FlexibleDate
err := fd.UnmarshalJSON([]byte(`null`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !fd.Time.IsZero() {
t.Errorf("expected zero time, got %v", fd.Time)
}
}
func TestFlexibleDate_UnmarshalJSON_EmptyString(t *testing.T) {
var fd FlexibleDate
err := fd.UnmarshalJSON([]byte(`""`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !fd.Time.IsZero() {
t.Errorf("expected zero time, got %v", fd.Time)
}
}
func TestFlexibleDate_UnmarshalJSON_Invalid(t *testing.T) {
var fd FlexibleDate
err := fd.UnmarshalJSON([]byte(`"not-a-date"`))
if err == nil {
t.Fatal("expected error for invalid date, got nil")
}
}
func TestFlexibleDate_MarshalJSON_Valid(t *testing.T) {
fd := FlexibleDate{Time: time.Date(2025, 11, 27, 15, 30, 0, 0, time.UTC)}
data, err := fd.MarshalJSON()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var s string
if err := json.Unmarshal(data, &s); err != nil {
t.Fatalf("result is not a JSON string: %v", err)
}
want := "2025-11-27T15:30:00Z"
if s != want {
t.Errorf("got %q, want %q", s, want)
}
}
func TestFlexibleDate_MarshalJSON_Zero(t *testing.T) {
fd := FlexibleDate{}
data, err := fd.MarshalJSON()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(data) != "null" {
t.Errorf("got %s, want null", string(data))
}
}
func TestFlexibleDate_ToTimePtr_Valid(t *testing.T) {
fd := &FlexibleDate{Time: time.Date(2025, 11, 27, 0, 0, 0, 0, time.UTC)}
ptr := fd.ToTimePtr()
if ptr == nil {
t.Fatal("expected non-nil pointer")
}
if !ptr.Equal(fd.Time) {
t.Errorf("got %v, want %v", *ptr, fd.Time)
}
}
func TestFlexibleDate_ToTimePtr_Zero(t *testing.T) {
fd := &FlexibleDate{}
ptr := fd.ToTimePtr()
if ptr != nil {
t.Errorf("expected nil, got %v", *ptr)
}
}
func TestFlexibleDate_ToTimePtr_NilReceiver(t *testing.T) {
var fd *FlexibleDate
ptr := fd.ToTimePtr()
if ptr != nil {
t.Errorf("expected nil for nil receiver, got %v", *ptr)
}
}
func TestFlexibleDate_RoundTrip(t *testing.T) {
original := FlexibleDate{Time: time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)}
data, err := original.MarshalJSON()
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var restored FlexibleDate
if err := restored.UnmarshalJSON(data); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if !original.Time.Equal(restored.Time) {
t.Errorf("round-trip mismatch: original %v, restored %v", original.Time, restored.Time)
}
}

View File

@@ -0,0 +1,833 @@
package responses
import (
"fmt"
"testing"
"time"
"github.com/shopspring/decimal"
"github.com/treytartt/honeydue-api/internal/models"
)
// --- helpers ---
func timePtr(t time.Time) *time.Time { return &t }
func uintPtr(v uint) *uint { return &v }
func intPtr(v int) *int { return &v }
func strPtr(v string) *string { return &v }
func float64Ptr(v float64) *float64 { return &v }
var fixedNow = time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
func makeUser() *models.User {
return &models.User{
ID: 1,
Username: "john",
Email: "john@example.com",
FirstName: "John",
LastName: "Doe",
IsActive: true,
DateJoined: fixedNow,
LastLogin: timePtr(fixedNow),
Profile: &models.UserProfile{
BaseModel: models.BaseModel{ID: 10},
UserID: 1,
Verified: true,
Bio: "hello",
},
}
}
func makeUserNoProfile() *models.User {
u := makeUser()
u.Profile = nil
return u
}
// ==================== auth.go ====================
func TestNewUserResponse_AllFields(t *testing.T) {
u := makeUser()
resp := NewUserResponse(u)
if resp.ID != 1 {
t.Errorf("ID = %d, want 1", resp.ID)
}
if resp.Username != "john" {
t.Errorf("Username = %q", resp.Username)
}
if !resp.Verified {
t.Error("Verified should be true when profile is verified")
}
if resp.LastLogin == nil {
t.Error("LastLogin should not be nil")
}
}
func TestNewUserResponse_NilProfile(t *testing.T) {
u := makeUserNoProfile()
resp := NewUserResponse(u)
if resp.Verified {
t.Error("Verified should be false when profile is nil")
}
}
func TestNewUserProfileResponse_Nil(t *testing.T) {
resp := NewUserProfileResponse(nil)
if resp != nil {
t.Error("expected nil for nil profile")
}
}
func TestNewUserProfileResponse_Valid(t *testing.T) {
p := &models.UserProfile{
BaseModel: models.BaseModel{ID: 5},
UserID: 1,
Verified: true,
Bio: "bio",
}
resp := NewUserProfileResponse(p)
if resp == nil {
t.Fatal("expected non-nil")
}
if resp.ID != 5 || resp.UserID != 1 || !resp.Verified || resp.Bio != "bio" {
t.Errorf("unexpected response: %+v", resp)
}
}
func TestNewCurrentUserResponse(t *testing.T) {
u := makeUser()
resp := NewCurrentUserResponse(u, "apple")
if resp.AuthProvider != "apple" {
t.Errorf("AuthProvider = %q, want apple", resp.AuthProvider)
}
if resp.Profile == nil {
t.Error("Profile should not be nil")
}
if resp.ID != 1 {
t.Errorf("ID = %d, want 1", resp.ID)
}
}
func TestNewLoginResponse(t *testing.T) {
u := makeUser()
resp := NewLoginResponse("tok123", u)
if resp.Token != "tok123" {
t.Errorf("Token = %q", resp.Token)
}
if resp.User.ID != 1 {
t.Errorf("User.ID = %d", resp.User.ID)
}
}
func TestNewRegisterResponse(t *testing.T) {
u := makeUser()
resp := NewRegisterResponse("tok456", u)
if resp.Token != "tok456" {
t.Errorf("Token = %q", resp.Token)
}
if resp.Message == "" {
t.Error("Message should not be empty")
}
}
func TestNewAppleSignInResponse(t *testing.T) {
u := makeUser()
resp := NewAppleSignInResponse("atok", u, true)
if !resp.IsNewUser {
t.Error("IsNewUser should be true")
}
if resp.Token != "atok" {
t.Errorf("Token = %q", resp.Token)
}
}
func TestNewGoogleSignInResponse(t *testing.T) {
u := makeUser()
resp := NewGoogleSignInResponse("gtok", u, false)
if resp.IsNewUser {
t.Error("IsNewUser should be false")
}
if resp.Token != "gtok" {
t.Errorf("Token = %q", resp.Token)
}
}
// ==================== task.go ====================
func makeTask() *models.Task {
due := time.Date(2025, 7, 1, 0, 0, 0, 0, time.UTC)
catID := uint(1)
priID := uint(2)
freqID := uint(3)
return &models.Task{
BaseModel: models.BaseModel{ID: 100, CreatedAt: fixedNow, UpdatedAt: fixedNow},
ResidenceID: 10,
CreatedByID: 1,
CreatedBy: *makeUser(),
Title: "Fix roof",
Description: "Repair leak",
CategoryID: &catID,
Category: &models.TaskCategory{BaseModel: models.BaseModel{ID: catID}, Name: "Exterior", Icon: "roof", Color: "#FF0000", DisplayOrder: 1},
PriorityID: &priID,
Priority: &models.TaskPriority{BaseModel: models.BaseModel{ID: priID}, Name: "High", Level: 3, Color: "#FF0000", DisplayOrder: 1},
FrequencyID: &freqID,
Frequency: &models.TaskFrequency{BaseModel: models.BaseModel{ID: freqID}, Name: "Monthly", Days: intPtr(30), DisplayOrder: 1},
DueDate: &due,
}
}
func TestNewTaskResponse_BasicFields(t *testing.T) {
task := makeTask()
resp := NewTaskResponseWithTime(task, 30, fixedNow)
if resp.ID != 100 {
t.Errorf("ID = %d", resp.ID)
}
if resp.Title != "Fix roof" {
t.Errorf("Title = %q", resp.Title)
}
if resp.CreatedBy == nil {
t.Error("CreatedBy should not be nil")
}
if resp.Category == nil {
t.Error("Category should not be nil")
}
if resp.Priority == nil {
t.Error("Priority should not be nil")
}
if resp.Frequency == nil {
t.Error("Frequency should not be nil")
}
if resp.KanbanColumn == "" {
t.Error("KanbanColumn should not be empty")
}
}
func TestNewTaskResponse_NilAssociations(t *testing.T) {
task := &models.Task{
BaseModel: models.BaseModel{ID: 200},
ResidenceID: 10,
CreatedByID: 1,
Title: "Simple task",
}
resp := NewTaskResponseWithTime(task, 30, fixedNow)
if resp.CreatedBy != nil {
t.Error("CreatedBy should be nil when CreatedBy.ID is 0")
}
if resp.Category != nil {
t.Error("Category should be nil")
}
if resp.Priority != nil {
t.Error("Priority should be nil")
}
if resp.Frequency != nil {
t.Error("Frequency should be nil")
}
if resp.AssignedTo != nil {
t.Error("AssignedTo should be nil")
}
}
func TestNewTaskResponse_WithCompletions(t *testing.T) {
task := makeTask()
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
{BaseModel: models.BaseModel{ID: 2}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
}
resp := NewTaskResponseWithTime(task, 30, fixedNow)
if resp.CompletionCount != 2 {
t.Errorf("CompletionCount = %d, want 2", resp.CompletionCount)
}
}
func TestNewTaskResponseWithTime_KanbanColumn(t *testing.T) {
task := makeTask()
// due date is July 1, now is June 15 → 16 days away → due_soon (within 30 days)
resp := NewTaskResponseWithTime(task, 30, fixedNow)
if resp.KanbanColumn == "" {
t.Error("KanbanColumn should be set")
}
}
func TestNewTaskListResponse(t *testing.T) {
tasks := []models.Task{
{BaseModel: models.BaseModel{ID: 1}, Title: "A"},
{BaseModel: models.BaseModel{ID: 2}, Title: "B"},
}
results := NewTaskListResponse(tasks)
if len(results) != 2 {
t.Errorf("len = %d, want 2", len(results))
}
}
func TestNewTaskListResponse_Empty(t *testing.T) {
results := NewTaskListResponse([]models.Task{})
if len(results) != 0 {
t.Errorf("len = %d, want 0", len(results))
}
}
func TestNewTaskCompletionResponse_WithImages(t *testing.T) {
c := &models.TaskCompletion{
BaseModel: models.BaseModel{ID: 50},
TaskID: 100,
CompletedByID: 1,
CompletedBy: *makeUser(),
CompletedAt: fixedNow,
Notes: "done",
Images: []models.TaskCompletionImage{
{BaseModel: models.BaseModel{ID: 1}, ImageURL: "http://img1.jpg", Caption: "before"},
{BaseModel: models.BaseModel{ID: 2}, ImageURL: "http://img2.jpg", Caption: "after"},
},
}
resp := NewTaskCompletionResponse(c)
if resp.CompletedBy == nil {
t.Error("CompletedBy should not be nil")
}
if len(resp.Images) != 2 {
t.Errorf("Images len = %d, want 2", len(resp.Images))
}
if resp.Images[0].MediaURL != "/api/media/completion-image/1" {
t.Errorf("MediaURL = %q", resp.Images[0].MediaURL)
}
}
func TestNewTaskCompletionResponse_EmptyImages(t *testing.T) {
c := &models.TaskCompletion{
BaseModel: models.BaseModel{ID: 51},
TaskID: 100,
CompletedByID: 1,
CompletedAt: fixedNow,
}
resp := NewTaskCompletionResponse(c)
if resp.Images == nil {
t.Error("Images should be empty slice, not nil")
}
if len(resp.Images) != 0 {
t.Errorf("Images len = %d, want 0", len(resp.Images))
}
}
func TestNewKanbanBoardResponse(t *testing.T) {
board := &models.KanbanBoard{
Columns: []models.KanbanColumn{
{
Name: "overdue",
DisplayName: "Overdue",
Color: "#FF0000",
Tasks: []models.Task{{BaseModel: models.BaseModel{ID: 1}, Title: "A"}},
Count: 1,
},
},
DaysThreshold: 30,
}
resp := NewKanbanBoardResponse(board, 10, fixedNow)
if len(resp.Columns) != 1 {
t.Fatalf("Columns len = %d", len(resp.Columns))
}
if resp.ResidenceID != "10" {
t.Errorf("ResidenceID = %q, want '10'", resp.ResidenceID)
}
if resp.Columns[0].Count != 1 {
t.Errorf("Count = %d", resp.Columns[0].Count)
}
}
func TestNewKanbanBoardResponseForAll(t *testing.T) {
board := &models.KanbanBoard{
Columns: []models.KanbanColumn{},
DaysThreshold: 30,
}
resp := NewKanbanBoardResponseForAll(board, fixedNow)
if resp.ResidenceID != "all" {
t.Errorf("ResidenceID = %q, want 'all'", resp.ResidenceID)
}
}
func TestDetermineKanbanColumn_Delegates(t *testing.T) {
task := &models.Task{
BaseModel: models.BaseModel{ID: 1},
Title: "test",
}
col := DetermineKanbanColumn(task, 30)
if col == "" {
t.Error("expected non-empty column")
}
}
func TestNewTaskCompletionWithTaskResponse(t *testing.T) {
c := &models.TaskCompletion{
BaseModel: models.BaseModel{ID: 1},
TaskID: 100,
CompletedByID: 1,
CompletedAt: fixedNow,
}
task := makeTask()
resp := NewTaskCompletionWithTaskResponseWithTime(c, task, 30, fixedNow)
if resp.Task == nil {
t.Error("Task should not be nil")
}
if resp.Task.ID != 100 {
t.Errorf("Task.ID = %d", resp.Task.ID)
}
}
func TestNewTaskCompletionWithTaskResponse_NilTask(t *testing.T) {
c := &models.TaskCompletion{
BaseModel: models.BaseModel{ID: 1},
TaskID: 100,
CompletedByID: 1,
CompletedAt: fixedNow,
}
resp := NewTaskCompletionWithTaskResponseWithTime(c, nil, 30, fixedNow)
if resp.Task != nil {
t.Error("Task should be nil")
}
}
func TestNewTaskCompletionListResponse(t *testing.T) {
completions := []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
}
results := NewTaskCompletionListResponse(completions)
if len(results) != 1 {
t.Errorf("len = %d", len(results))
}
}
func TestNewTaskCategoryResponse_Nil(t *testing.T) {
if NewTaskCategoryResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewTaskPriorityResponse_Nil(t *testing.T) {
if NewTaskPriorityResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewTaskFrequencyResponse_Nil(t *testing.T) {
if NewTaskFrequencyResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewTaskUserResponse_Nil(t *testing.T) {
if NewTaskUserResponse(nil) != nil {
t.Error("expected nil")
}
}
// ==================== contractor.go ====================
func makeContractor() *models.Contractor {
resID := uint(10)
return &models.Contractor{
BaseModel: models.BaseModel{ID: 5, CreatedAt: fixedNow, UpdatedAt: fixedNow},
ResidenceID: &resID,
CreatedByID: 1,
CreatedBy: *makeUser(),
Name: "Bob's Plumbing",
Company: "Bob Co",
Phone: "555-1234",
Email: "bob@plumb.com",
Rating: float64Ptr(4.5),
IsFavorite: true,
IsActive: true,
Specialties: []models.ContractorSpecialty{
{BaseModel: models.BaseModel{ID: 1}, Name: "Plumbing", Icon: "wrench", DisplayOrder: 1},
},
Tasks: []models.Task{{BaseModel: models.BaseModel{ID: 1}}, {BaseModel: models.BaseModel{ID: 2}}},
}
}
func TestNewContractorResponse_BasicFields(t *testing.T) {
c := makeContractor()
resp := NewContractorResponse(c)
if resp.ID != 5 {
t.Errorf("ID = %d", resp.ID)
}
if resp.Name != "Bob's Plumbing" {
t.Errorf("Name = %q", resp.Name)
}
if resp.AddedBy != 1 {
t.Errorf("AddedBy = %d, want 1", resp.AddedBy)
}
if resp.CreatedBy == nil {
t.Error("CreatedBy should not be nil")
}
if resp.TaskCount != 2 {
t.Errorf("TaskCount = %d, want 2", resp.TaskCount)
}
}
func TestNewContractorResponse_WithSpecialties(t *testing.T) {
c := makeContractor()
resp := NewContractorResponse(c)
if len(resp.Specialties) != 1 {
t.Fatalf("Specialties len = %d", len(resp.Specialties))
}
if resp.Specialties[0].Name != "Plumbing" {
t.Errorf("Specialty name = %q", resp.Specialties[0].Name)
}
}
func TestNewContractorListResponse(t *testing.T) {
contractors := []models.Contractor{*makeContractor()}
results := NewContractorListResponse(contractors)
if len(results) != 1 {
t.Errorf("len = %d", len(results))
}
}
func TestNewContractorUserResponse_Nil(t *testing.T) {
if NewContractorUserResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewContractorSpecialtyResponse(t *testing.T) {
s := &models.ContractorSpecialty{
BaseModel: models.BaseModel{ID: 1},
Name: "Electrical",
Description: "Electrical work",
Icon: "bolt",
DisplayOrder: 2,
}
resp := NewContractorSpecialtyResponse(s)
if resp.Name != "Electrical" || resp.Icon != "bolt" {
t.Errorf("unexpected: %+v", resp)
}
}
// ==================== document.go ====================
func makeDocument() *models.Document {
price := decimal.NewFromFloat(99.99)
return &models.Document{
BaseModel: models.BaseModel{ID: 20, CreatedAt: fixedNow, UpdatedAt: fixedNow},
ResidenceID: 10,
CreatedByID: 1,
CreatedBy: *makeUser(),
Title: "Warranty",
Description: "Roof warranty",
DocumentType: "warranty",
FileName: "warranty.pdf",
FileSize: func() *int64 { v := int64(1024); return &v }(),
MimeType: "application/pdf",
PurchasePrice: &price,
IsActive: true,
Images: []models.DocumentImage{
{BaseModel: models.BaseModel{ID: 1}, ImageURL: "http://img.jpg", Caption: "page 1"},
},
}
}
func TestNewDocumentResponse_MediaURL(t *testing.T) {
d := makeDocument()
resp := NewDocumentResponse(d)
want := fmt.Sprintf("/api/media/document/%d", d.ID)
if resp.MediaURL != want {
t.Errorf("MediaURL = %q, want %q", resp.MediaURL, want)
}
if resp.Residence != resp.ResidenceID {
t.Error("Residence alias should equal ResidenceID")
}
}
func TestNewDocumentResponse_WithImages(t *testing.T) {
d := makeDocument()
resp := NewDocumentResponse(d)
if len(resp.Images) != 1 {
t.Fatalf("Images len = %d", len(resp.Images))
}
if resp.Images[0].MediaURL != "/api/media/document-image/1" {
t.Errorf("Image MediaURL = %q", resp.Images[0].MediaURL)
}
}
func TestNewDocumentResponse_EmptyImageURL(t *testing.T) {
d := makeDocument()
d.Images = []models.DocumentImage{
{BaseModel: models.BaseModel{ID: 5}, ImageURL: "", Caption: "missing"},
}
resp := NewDocumentResponse(d)
if resp.Images[0].Error != "image source URL is missing" {
t.Errorf("Error = %q", resp.Images[0].Error)
}
}
func TestNewDocumentListResponse(t *testing.T) {
docs := []models.Document{*makeDocument()}
results := NewDocumentListResponse(docs)
if len(results) != 1 {
t.Errorf("len = %d", len(results))
}
}
func TestNewDocumentUserResponse_Nil(t *testing.T) {
if NewDocumentUserResponse(nil) != nil {
t.Error("expected nil")
}
}
// ==================== residence.go ====================
func makeResidence() *models.Residence {
propTypeID := uint(1)
return &models.Residence{
BaseModel: models.BaseModel{ID: 10, CreatedAt: fixedNow, UpdatedAt: fixedNow},
OwnerID: 1,
Owner: *makeUser(),
Name: "My House",
PropertyTypeID: &propTypeID,
PropertyType: &models.ResidenceType{BaseModel: models.BaseModel{ID: 1}, Name: "House"},
StreetAddress: "123 Main St",
City: "Springfield",
StateProvince: "IL",
PostalCode: "62701",
Country: "USA",
Bedrooms: intPtr(3),
IsPrimary: true,
IsActive: true,
HasPool: true,
HeatingType: strPtr("central"),
Users: []models.User{
{ID: 1, Username: "john", Email: "john@example.com"},
{ID: 2, Username: "jane", Email: "jane@example.com"},
},
}
}
func TestNewResidenceResponse_AllFields(t *testing.T) {
r := makeResidence()
resp := NewResidenceResponse(r)
if resp.ID != 10 {
t.Errorf("ID = %d", resp.ID)
}
if resp.Name != "My House" {
t.Errorf("Name = %q", resp.Name)
}
if resp.Owner == nil {
t.Error("Owner should not be nil")
}
if resp.PropertyType == nil {
t.Error("PropertyType should not be nil")
}
if !resp.HasPool {
t.Error("HasPool should be true")
}
if resp.HeatingType == nil || *resp.HeatingType != "central" {
t.Error("HeatingType should be 'central'")
}
}
func TestNewResidenceResponse_WithUsers(t *testing.T) {
r := makeResidence()
resp := NewResidenceResponse(r)
if len(resp.Users) != 2 {
t.Errorf("Users len = %d, want 2", len(resp.Users))
}
}
func TestNewResidenceResponse_NoUsers(t *testing.T) {
r := makeResidence()
r.Users = nil
resp := NewResidenceResponse(r)
if resp.Users == nil {
t.Error("Users should be empty slice, not nil")
}
if len(resp.Users) != 0 {
t.Errorf("Users len = %d, want 0", len(resp.Users))
}
}
func TestNewResidenceListResponse(t *testing.T) {
residences := []models.Residence{*makeResidence()}
results := NewResidenceListResponse(residences)
if len(results) != 1 {
t.Errorf("len = %d", len(results))
}
}
func TestNewResidenceUserResponse_Nil(t *testing.T) {
if NewResidenceUserResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewResidenceTypeResponse_Nil(t *testing.T) {
if NewResidenceTypeResponse(nil) != nil {
t.Error("expected nil")
}
}
func TestNewShareCodeResponse(t *testing.T) {
sc := &models.ResidenceShareCode{
BaseModel: models.BaseModel{ID: 1, CreatedAt: fixedNow},
Code: "ABC123",
ResidenceID: 10,
CreatedByID: 1,
IsActive: true,
ExpiresAt: timePtr(fixedNow.Add(24 * time.Hour)),
}
resp := NewShareCodeResponse(sc)
if resp.Code != "ABC123" {
t.Errorf("Code = %q", resp.Code)
}
if resp.ResidenceID != 10 {
t.Errorf("ResidenceID = %d", resp.ResidenceID)
}
}
// ==================== task_template.go ====================
func TestParseTags_Empty(t *testing.T) {
result := parseTags("")
if len(result) != 0 {
t.Errorf("len = %d, want 0", len(result))
}
}
func TestParseTags_Multiple(t *testing.T) {
result := parseTags("plumbing,electrical,roofing")
if len(result) != 3 {
t.Errorf("len = %d, want 3", len(result))
}
if result[0] != "plumbing" || result[1] != "electrical" || result[2] != "roofing" {
t.Errorf("unexpected tags: %v", result)
}
}
func TestParseTags_Whitespace(t *testing.T) {
result := parseTags(" plumbing , , electrical ")
if len(result) != 2 {
t.Errorf("len = %d, want 2 (should skip empty after trim)", len(result))
}
if result[0] != "plumbing" || result[1] != "electrical" {
t.Errorf("unexpected tags: %v", result)
}
}
func makeTemplate(catID *uint, cat *models.TaskCategory) models.TaskTemplate {
return models.TaskTemplate{
BaseModel: models.BaseModel{ID: 1, CreatedAt: fixedNow, UpdatedAt: fixedNow},
Title: "Clean Gutters",
Description: "Remove debris",
CategoryID: catID,
Category: cat,
IconIOS: "leaf",
IconAndroid: "leaf_android",
Tags: "exterior,seasonal",
DisplayOrder: 1,
IsActive: true,
}
}
func TestNewTaskTemplateResponse(t *testing.T) {
catID := uint(1)
cat := &models.TaskCategory{BaseModel: models.BaseModel{ID: 1}, Name: "Exterior"}
tmpl := makeTemplate(&catID, cat)
resp := NewTaskTemplateResponse(&tmpl)
if resp.Title != "Clean Gutters" {
t.Errorf("Title = %q", resp.Title)
}
if len(resp.Tags) != 2 {
t.Errorf("Tags len = %d", len(resp.Tags))
}
if resp.Category == nil {
t.Error("Category should not be nil")
}
}
func TestNewTaskTemplateResponse_WithRegion(t *testing.T) {
tmpl := makeTemplate(nil, nil)
tmpl.Regions = []models.ClimateRegion{
{BaseModel: models.BaseModel{ID: 5}, Name: "Southeast"},
}
resp := NewTaskTemplateResponse(&tmpl)
if resp.RegionID == nil || *resp.RegionID != 5 {
t.Error("RegionID should be 5")
}
if resp.RegionName != "Southeast" {
t.Errorf("RegionName = %q", resp.RegionName)
}
}
func TestNewTaskTemplatesGroupedResponse_Grouping(t *testing.T) {
catID := uint(1)
cat := &models.TaskCategory{BaseModel: models.BaseModel{ID: 1}, Name: "Exterior"}
templates := []models.TaskTemplate{
makeTemplate(&catID, cat),
makeTemplate(&catID, cat),
}
resp := NewTaskTemplatesGroupedResponse(templates)
if len(resp.Categories) != 1 {
t.Fatalf("Categories len = %d, want 1", len(resp.Categories))
}
if resp.Categories[0].CategoryName != "Exterior" {
t.Errorf("CategoryName = %q", resp.Categories[0].CategoryName)
}
if resp.Categories[0].Count != 2 {
t.Errorf("Count = %d, want 2", resp.Categories[0].Count)
}
if resp.TotalCount != 2 {
t.Errorf("TotalCount = %d, want 2", resp.TotalCount)
}
}
func TestNewTaskTemplatesGroupedResponse_Uncategorized(t *testing.T) {
tmpl := makeTemplate(nil, nil)
resp := NewTaskTemplatesGroupedResponse([]models.TaskTemplate{tmpl})
if len(resp.Categories) != 1 {
t.Fatalf("Categories len = %d", len(resp.Categories))
}
if resp.Categories[0].CategoryName != "Uncategorized" {
t.Errorf("CategoryName = %q", resp.Categories[0].CategoryName)
}
}
func TestNewTaskTemplateListResponse(t *testing.T) {
templates := []models.TaskTemplate{makeTemplate(nil, nil)}
results := NewTaskTemplateListResponse(templates)
if len(results) != 1 {
t.Errorf("len = %d", len(results))
}
}
// ==================== DetermineKanbanColumnWithTime ====================
func TestDetermineKanbanColumnWithTime(t *testing.T) {
task := makeTask()
col := DetermineKanbanColumnWithTime(task, 30, fixedNow)
if col == "" {
t.Error("expected non-empty column")
}
}
// ==================== NewTaskResponse uses NewTaskResponseWithThreshold ====================
func TestNewTaskResponse_UsesDefault30(t *testing.T) {
task := makeTask()
resp := NewTaskResponse(task)
if resp.ID != 100 {
t.Errorf("ID = %d", resp.ID)
}
// Just verify it doesn't panic and produces a response
}
// ==================== NewTaskCompletionWithTaskResponse UTC variant ====================
func TestNewTaskCompletionWithTaskResponse_UTC(t *testing.T) {
c := &models.TaskCompletion{
BaseModel: models.BaseModel{ID: 1},
TaskID: 100,
CompletedByID: 1,
CompletedAt: fixedNow,
}
task := makeTask()
resp := NewTaskCompletionWithTaskResponse(c, task, 30)
if resp.Task == nil {
t.Error("Task should not be nil")
}
}

View File

@@ -0,0 +1,105 @@
package echohelpers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDefaultQuery(t *testing.T) {
tests := []struct {
name string
query string
key string
defaultValue string
expected string
}{
{"returns value when present", "/?status=active", "status", "all", "active"},
{"returns default when absent", "/", "status", "all", "all"},
{"returns default for empty value", "/?status=", "status", "all", "all"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, tc.query, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := DefaultQuery(c, tc.key, tc.defaultValue)
assert.Equal(t, tc.expected, result)
})
}
}
func TestParseUintParam(t *testing.T) {
tests := []struct {
name string
paramValue string
expected uint
expectError bool
}{
{"valid uint", "42", 42, false},
{"zero", "0", 0, false},
{"invalid string", "abc", 0, true},
{"negative", "-1", 0, true},
{"empty", "", 0, true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetParamNames("id")
c.SetParamValues(tc.paramValue)
result, err := ParseUintParam(c, "id")
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expected, result)
}
})
}
}
func TestParseIntParam(t *testing.T) {
tests := []struct {
name string
paramValue string
expected int
expectError bool
}{
{"valid int", "42", 42, false},
{"zero", "0", 0, false},
{"negative", "-5", -5, false},
{"invalid string", "abc", 0, true},
{"empty", "", 0, true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetParamNames("id")
c.SetParamValues(tc.paramValue)
result, err := ParseIntParam(c, "id")
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expected, result)
}
})
}
}

View File

@@ -38,7 +38,7 @@ func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB
func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "password123")
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "Password123")
// Create profile for the user
profile := &models.UserProfile{UserID: user.ID, Verified: true}
@@ -52,7 +52,7 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
authGroup.DELETE("/account/", handler.DeleteAccount)
t.Run("successful deletion with correct password", func(t *testing.T) {
password := "password123"
password := "Password123"
req := map[string]interface{}{
"password": password,
}
@@ -84,7 +84,7 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "password123")
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "Password123")
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
@@ -105,7 +105,7 @@ func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "password123")
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "Password123")
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
@@ -207,7 +207,7 @@ func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
t.Run("unauthenticated request returns 401", func(t *testing.T) {
req := map[string]interface{}{
"password": "password123",
"password": "Password123",
}
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "")

View File

@@ -43,7 +43,7 @@ func TestAuthHandler_Register(t *testing.T) {
req := requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "password123",
Password: "Password123",
FirstName: "New",
LastName: "User",
}
@@ -98,7 +98,7 @@ func TestAuthHandler_Register(t *testing.T) {
req := requests.RegisterRequest{
Username: "duplicate",
Email: "unique1@test.com",
Password: "password123",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
@@ -117,7 +117,7 @@ func TestAuthHandler_Register(t *testing.T) {
req := requests.RegisterRequest{
Username: "user1",
Email: "duplicate@test.com",
Password: "password123",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
@@ -142,7 +142,7 @@ func TestAuthHandler_Login(t *testing.T) {
registerReq := requests.RegisterRequest{
Username: "logintest",
Email: "login@test.com",
Password: "password123",
Password: "Password123",
FirstName: "Test",
LastName: "User",
}
@@ -152,7 +152,7 @@ func TestAuthHandler_Login(t *testing.T) {
t.Run("successful login with username", func(t *testing.T) {
req := requests.LoginRequest{
Username: "logintest",
Password: "password123",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
@@ -174,7 +174,7 @@ func TestAuthHandler_Login(t *testing.T) {
t.Run("successful login with email", func(t *testing.T) {
req := requests.LoginRequest{
Username: "login@test.com", // Using email as username
Password: "password123",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
@@ -199,7 +199,7 @@ func TestAuthHandler_Login(t *testing.T) {
t.Run("login with non-existent user", func(t *testing.T) {
req := requests.LoginRequest{
Username: "nonexistent",
Password: "password123",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
@@ -223,7 +223,7 @@ func TestAuthHandler_CurrentUser(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t)
db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "password123")
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "Password123")
user.FirstName = "Test"
user.LastName = "User"
userRepo.Update(user)
@@ -251,7 +251,7 @@ func TestAuthHandler_UpdateProfile(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t)
db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "password123")
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "Password123")
userRepo.Update(user)
authGroup := e.Group("/api/auth")
@@ -289,7 +289,7 @@ func TestAuthHandler_ForgotPassword(t *testing.T) {
registerReq := requests.RegisterRequest{
Username: "forgottest",
Email: "forgot@test.com",
Password: "password123",
Password: "Password123",
}
testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
@@ -323,7 +323,7 @@ func TestAuthHandler_Logout(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t)
db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "password123")
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "Password123")
userRepo.Update(user)
authGroup := e.Group("/api/auth")
@@ -350,7 +350,7 @@ func TestAuthHandler_JSONResponses(t *testing.T) {
req := requests.RegisterRequest{
Username: "jsontest",
Email: "json@test.com",
Password: "password123",
Password: "Password123",
FirstName: "JSON",
LastName: "Test",
}

View File

@@ -2,6 +2,7 @@ package handlers
import (
"encoding/json"
"fmt"
"net/http"
"testing"
@@ -180,3 +181,284 @@ func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestContractorHandler_ListContractors(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Electrician Bob")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", handler.ListContractors)
t.Run("successful list", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response []map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Len(t, response, 2)
})
t.Run("user with no contractors returns empty", func(t *testing.T) {
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
e2 := testutil.SetupTestRouter()
authGroup2 := e2.Group("/api/contractors")
authGroup2.Use(testutil.MockAuthMiddleware(otherUser))
authGroup2.GET("/", handler.ListContractors)
w := testutil.MakeRequest(e2, "GET", "/api/contractors/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response []map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Len(t, response, 0)
})
}
func TestContractorHandler_GetContractor(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/:id/", handler.GetContractor)
t.Run("successful get", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/%d/", contractor.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "Plumber Joe", response["name"])
})
t.Run("not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/99999/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestContractorHandler_UpdateContractor(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.PUT("/:id/", handler.UpdateContractor)
t.Run("successful update", func(t *testing.T) {
newName := "Plumber Joe Updated"
req := requests.UpdateContractorRequest{
Name: &newName,
}
w := testutil.MakeRequest(e, "PUT", fmt.Sprintf("/api/contractors/%d/", contractor.ID), req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "Plumber Joe Updated", response["name"])
})
t.Run("invalid id returns 400", func(t *testing.T) {
newName := "Updated"
req := requests.UpdateContractorRequest{Name: &newName}
w := testutil.MakeRequest(e, "PUT", "/api/contractors/invalid/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("not found returns 404", func(t *testing.T) {
newName := "Updated"
req := requests.UpdateContractorRequest{Name: &newName}
w := testutil.MakeRequest(e, "PUT", "/api/contractors/99999/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestContractorHandler_DeleteContractor(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/:id/", handler.DeleteContractor)
t.Run("successful delete", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", fmt.Sprintf("/api/contractors/%d/", contractor.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "message")
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/contractors/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/contractors/99999/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestContractorHandler_ToggleFavorite(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/:id/toggle-favorite/", handler.ToggleFavorite)
t.Run("toggle favorite on", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/contractors/%d/toggle-favorite/", contractor.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "is_favorite")
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/contractors/invalid/toggle-favorite/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/contractors/99999/toggle-favorite/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestContractorHandler_ListContractorsByResidence(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/by-residence/:residence_id/", handler.ListContractorsByResidence)
t.Run("successful list by residence", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/by-residence/%d/", residence.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response []map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Len(t, response, 1)
})
t.Run("invalid residence id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/by-residence/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestContractorHandler_GetSpecialties(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/specialties/", handler.GetSpecialties)
t.Run("successful list specialties", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/specialties/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response []map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Greater(t, len(response), 0)
})
}
func TestContractorHandler_GetContractorTasks(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/:id/tasks/", handler.GetContractorTasks)
t.Run("successful get tasks", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/%d/tasks/", contractor.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/contractors/invalid/tasks/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestContractorHandler_CreateContractor_WithOptionalFields(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateContractor)
t.Run("creation with all optional fields", func(t *testing.T) {
rating := 4.5
isFavorite := true
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Full Contractor",
Company: "ABC Plumbing",
Phone: "555-1234",
Email: "contractor@test.com",
Notes: "Great work",
Rating: &rating,
IsFavorite: &isFavorite,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "Full Contractor", response["name"])
assert.Equal(t, "ABC Plumbing", response["company"])
})
}

View File

@@ -224,3 +224,235 @@ func TestDocumentHandler_DeleteDocument(t *testing.T) {
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestDocumentHandler_UpdateDocument(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.PUT("/:id/", handler.UpdateDocument)
t.Run("successful update", func(t *testing.T) {
newTitle := "Updated Title"
req := map[string]interface{}{
"title": newTitle,
}
w := testutil.MakeRequest(e, "PUT", fmt.Sprintf("/api/documents/%d/", doc.ID), req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "Updated Title", response["title"])
})
t.Run("invalid id returns 400", func(t *testing.T) {
req := map[string]interface{}{"title": "Updated"}
w := testutil.MakeRequest(e, "PUT", "/api/documents/invalid/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("not found returns 404", func(t *testing.T) {
req := map[string]interface{}{"title": "Updated"}
w := testutil.MakeRequest(e, "PUT", "/api/documents/99999/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
t.Run("access denied for other user", func(t *testing.T) {
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
e2 := testutil.SetupTestRouter()
otherGroup := e2.Group("/api/documents")
otherGroup.Use(testutil.MockAuthMiddleware(otherUser))
otherGroup.PUT("/:id/", handler.UpdateDocument)
req := map[string]interface{}{"title": "Hacked"}
w := testutil.MakeRequest(e2, "PUT", fmt.Sprintf("/api/documents/%d/", doc.ID), req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusForbidden)
})
}
func TestDocumentHandler_ListDocuments_Filters(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", handler.ListDocuments)
t.Run("filter by residence", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/documents/?residence=%d", residence.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response []map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Len(t, response, 1)
})
t.Run("filter by search", func(t *testing.T) {
t.Skip("ILIKE is not supported in SQLite; search filter requires PostgreSQL")
})
t.Run("expiring_soon out of range returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/documents/?expiring_soon=5000", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestDocumentHandler_ListWarranties(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a warranty document
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Warranty Doc")
require.NoError(t, db.Model(doc).Update("document_type", "warranty").Error)
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/warranties/", handler.ListWarranties)
t.Run("successful list warranties", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/documents/warranties/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
})
}
func TestDocumentHandler_ActivateDeactivateDocument(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Toggle Doc")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/:id/deactivate/", handler.DeactivateDocument)
authGroup.POST("/:id/activate/", handler.ActivateDocument)
t.Run("deactivate document", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/documents/%d/deactivate/", doc.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, false, response["is_active"])
})
t.Run("activate document", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/documents/%d/activate/", doc.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, true, response["is_active"])
})
t.Run("activate invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/documents/invalid/activate/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("deactivate invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/documents/invalid/deactivate/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("activate not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/documents/99999/activate/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
t.Run("deactivate not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/documents/99999/deactivate/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestDocumentHandler_CreateDocument_ValidationErrors(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateDocument)
t.Run("missing title returns 400", func(t *testing.T) {
body := map[string]interface{}{
"residence_id": residence.ID,
}
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("missing residence_id returns 400", func(t *testing.T) {
body := map[string]interface{}{
"title": "Test Doc",
}
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("invalid document_type returns 400", func(t *testing.T) {
body := map[string]interface{}{
"title": "Test Doc",
"residence_id": residence.ID,
"document_type": "invalid_type",
}
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestDocumentHandler_GetDocument_InvalidID(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/:id/", handler.GetDocument)
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/documents/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestDocumentHandler_DeleteDocument_InvalidID(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/:id/", handler.DeleteDocument)
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/documents/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestDocumentHandler_DeleteDocument_AccessDenied(t *testing.T) {
handler, e, db := setupDocumentHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
authGroup := e.Group("/api/documents")
authGroup.Use(testutil.MockAuthMiddleware(otherUser))
authGroup.DELETE("/:id/", handler.DeleteDocument)
t.Run("access denied for other user", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", fmt.Sprintf("/api/documents/%d/", doc.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusForbidden)
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -86,3 +86,323 @@ func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
assert.Equal(t, 50, count, "response should use default limit of 50")
})
}
func TestNotificationHandler_ListNotifications_Pagination(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
createTestNotifications(t, db, user.ID, 20)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", handler.ListNotifications)
t.Run("offset skips notifications", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=5&offset=15", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 5, count, "should return remaining 5 after offset 15")
})
t.Run("response has results array", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=3", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "results")
assert.Contains(t, response, "count")
results := response["results"].([]interface{})
assert.Len(t, results, 3)
})
t.Run("negative limit ignored", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=-5", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Negative limit should default to 50 (since -5 > 0 is false)
count := int(response["count"].(float64))
assert.Equal(t, 20, count, "should return all 20 with default limit of 50")
})
}
func TestNotificationHandler_GetUnreadCount(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create some unread notifications
createTestNotifications(t, db, user.ID, 5)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/unread-count/", handler.GetUnreadCount)
t.Run("successful unread count", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/unread-count/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "unread_count")
unreadCount := int(response["unread_count"].(float64))
assert.Equal(t, 5, unreadCount)
})
t.Run("user with no notifications returns zero", func(t *testing.T) {
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
e2 := testutil.SetupTestRouter()
authGroup2 := e2.Group("/api/notifications")
authGroup2.Use(testutil.MockAuthMiddleware(otherUser))
authGroup2.GET("/unread-count/", handler.GetUnreadCount)
w := testutil.MakeRequest(e2, "GET", "/api/notifications/unread-count/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, float64(0), response["unread_count"])
})
}
func TestNotificationHandler_MarkAsRead(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create a notification
notif := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Test Notification",
Body: "Test Body",
}
require.NoError(t, db.Create(notif).Error)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/:id/read/", handler.MarkAsRead)
t.Run("successful mark as read", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/notifications/%d/read/", notif.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "message")
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/notifications/invalid/read/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("not found returns 404", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/notifications/99999/read/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusNotFound)
})
}
func TestNotificationHandler_MarkAllAsRead(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
createTestNotifications(t, db, user.ID, 5)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/mark-all-read/", handler.MarkAllAsRead)
authGroup.GET("/unread-count/", handler.GetUnreadCount)
t.Run("successful mark all as read", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/notifications/mark-all-read/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "message")
})
t.Run("unread count is zero after mark all", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/unread-count/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, float64(0), response["unread_count"])
})
}
func TestNotificationHandler_GetPreferences(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/preferences/", handler.GetPreferences)
t.Run("successful get preferences", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/preferences/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Default preferences should have standard fields
assert.Contains(t, response, "task_due_soon")
assert.Contains(t, response, "task_overdue")
assert.Contains(t, response, "task_completed")
})
}
func TestNotificationHandler_UpdatePreferences(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.PUT("/preferences/", handler.UpdatePreferences)
t.Run("successful update preferences", func(t *testing.T) {
req := map[string]interface{}{
"task_due_soon": false,
"task_overdue": true,
}
w := testutil.MakeRequest(e, "PUT", "/api/notifications/preferences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, false, response["task_due_soon"])
assert.Equal(t, true, response["task_overdue"])
})
}
func TestNotificationHandler_RegisterDevice(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/devices/", handler.RegisterDevice)
t.Run("successful device registration", func(t *testing.T) {
req := map[string]interface{}{
"name": "iPhone 15",
"device_id": "test-device-id-123",
"registration_id": "test-registration-id-abc",
"platform": "ios",
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
t.Run("missing required fields returns 400", func(t *testing.T) {
req := map[string]interface{}{
"name": "iPhone 15",
// Missing device_id, registration_id, platform
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("invalid platform returns 400", func(t *testing.T) {
req := map[string]interface{}{
"device_id": "test-device-id-456",
"registration_id": "test-registration-id-def",
"platform": "windows", // invalid
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestNotificationHandler_ListDevices(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/devices/", handler.ListDevices)
t.Run("successful list devices", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/devices/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
})
}
func TestNotificationHandler_UnregisterDevice(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/devices/unregister/", handler.UnregisterDevice)
t.Run("missing registration_id returns 400", func(t *testing.T) {
req := map[string]interface{}{
"platform": "ios",
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("missing platform returns 400", func(t *testing.T) {
req := map[string]interface{}{
"registration_id": "test-id",
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("invalid platform returns 400", func(t *testing.T) {
req := map[string]interface{}{
"registration_id": "test-id",
"platform": "windows",
}
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestNotificationHandler_DeleteDevice(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/devices/:id/", handler.DeleteDevice)
t.Run("missing platform query param returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/1/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("invalid platform returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/1/?platform=windows", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/invalid/?platform=ios", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}

View File

@@ -567,3 +567,164 @@ func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}
func TestResidenceHandler_GetMyResidences(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
testutil.CreateTestResidence(t, db, user.ID, "House 1")
testutil.CreateTestResidence(t, db, user.ID, "House 2")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/my-residences/", handler.GetMyResidences)
t.Run("successful my residences", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/residences/my-residences/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
// GetMyResidences returns MyResidencesResponse: {"residences": [...]}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
residences := response["residences"].([]interface{})
assert.Len(t, residences, 2)
})
t.Run("user with no residences returns empty", func(t *testing.T) {
noResUser := testutil.CreateTestUser(t, db, "nores", "nores@test.com", "Password123")
e2 := testutil.SetupTestRouter()
authGroup2 := e2.Group("/api/residences")
authGroup2.Use(testutil.MockAuthMiddleware(noResUser))
authGroup2.GET("/my-residences/", handler.GetMyResidences)
w := testutil.MakeRequest(e2, "GET", "/api/residences/my-residences/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
// GetMyResidences returns MyResidencesResponse: {"residences": [...] or null}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
if response["residences"] == nil {
// null residences means no residences
} else {
residences := response["residences"].([]interface{})
assert.Len(t, residences, 0)
}
})
}
func TestResidenceHandler_GetSummary(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
testutil.CreateTestResidence(t, db, user.ID, "House 1")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/summary/", handler.GetSummary)
t.Run("successful summary", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/residences/summary/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "total_residences")
assert.Contains(t, response, "total_tasks")
})
}
func TestResidenceHandler_UpdateResidence_InvalidID(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.PUT("/:id/", handler.UpdateResidence)
t.Run("invalid id returns 400", func(t *testing.T) {
newName := "Updated"
req := requests.UpdateResidenceRequest{Name: &newName}
w := testutil.MakeRequest(e, "PUT", "/api/residences/invalid/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("non-existent id returns 403", func(t *testing.T) {
newName := "Updated"
req := requests.UpdateResidenceRequest{Name: &newName}
w := testutil.MakeRequest(e, "PUT", "/api/residences/9999/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusForbidden)
})
}
func TestResidenceHandler_DeleteResidence_InvalidID(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/:id/", handler.DeleteResidence)
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/residences/invalid/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("non-existent id returns 403", func(t *testing.T) {
w := testutil.MakeRequest(e, "DELETE", "/api/residences/9999/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusForbidden)
})
}
func TestResidenceHandler_GetShareCode(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Share Code Test")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/:id/share-code/", handler.GetShareCode)
t.Run("no share code returns null", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/residences/%d/share-code/", residence.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Nil(t, response["share_code"])
})
t.Run("invalid id returns 400", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/residences/invalid/share-code/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestResidenceHandler_GenerateSharePackage(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Package Test")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/:id/generate-share-package/", handler.GenerateSharePackage)
t.Run("generate share package", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/residences/%d/generate-share-package/", residence.ID), nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "share_code")
})
}

211
internal/i18n/i18n_test.go Normal file
View File

@@ -0,0 +1,211 @@
package i18n
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInit(t *testing.T) {
err := Init()
require.NoError(t, err)
assert.NotNil(t, Bundle)
}
func TestTSimple_EnglishKnownKey(t *testing.T) {
require.NoError(t, Init())
localizer := NewLocalizer("en")
msg := TSimple(localizer, "error.task_not_found")
assert.Equal(t, "Task not found", msg)
}
func TestTSimple_SpanishKnownKey(t *testing.T) {
require.NoError(t, Init())
localizer := NewLocalizer("es")
msg := TSimple(localizer, "error.invalid_credentials")
assert.Equal(t, "Credenciales no validas", msg)
}
func TestT_WithTemplateData(t *testing.T) {
require.NoError(t, Init())
localizer := NewLocalizer("en")
msg := T(localizer, "message.tasks_report_sent", map[string]interface{}{
"Email": "test@example.com",
})
assert.Contains(t, msg, "test@example.com")
}
func TestTSimple_UnknownKeyReturnsKey(t *testing.T) {
require.NoError(t, Init())
localizer := NewLocalizer("en")
key := "error.nonexistent_key_that_does_not_exist"
msg := TSimple(localizer, key)
assert.Equal(t, key, msg)
}
func TestTSimple_FallbackToEnglish(t *testing.T) {
require.NoError(t, Init())
// Use a language that may not have all translations — fallback to English
localizer := NewLocalizer("xx", "en")
msg := TSimple(localizer, "error.task_not_found")
assert.Equal(t, "Task not found", msg)
}
func TestT_NilLocalizer_UsesDefault(t *testing.T) {
require.NoError(t, Init())
msg := T(nil, "error.task_not_found", nil)
assert.Equal(t, "Task not found", msg)
}
func TestNewLocalizer(t *testing.T) {
require.NoError(t, Init())
localizer := NewLocalizer("en")
assert.NotNil(t, localizer)
}
func TestParseAcceptLanguage(t *testing.T) {
tests := []struct {
name string
header string
expected []string
}{
{"empty returns default", "", []string{"en"}},
{"english", "en-US,en;q=0.9", []string{"en", "en"}},
{"spanish first", "es,en;q=0.5", []string{"es", "en"}},
{"unsupported returns default", "xx-YY", []string{"en"}},
{"french", "fr-FR,fr;q=0.9,en;q=0.5", []string{"fr", "fr", "en"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := parseAcceptLanguage(tc.header)
assert.Equal(t, tc.expected, result)
})
}
}
func TestMatchLocale(t *testing.T) {
tests := []struct {
name string
langs []string
expected string
}{
{"finds supported", []string{"es", "en"}, "es"},
{"first match wins", []string{"fr", "de"}, "fr"},
{"unsupported returns default", []string{"xx"}, "en"},
{"empty returns default", []string{}, "en"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := matchLocale(tc.langs)
assert.Equal(t, tc.expected, result)
})
}
}
func TestMiddleware_SetsLocalizerAndLocale(t *testing.T) {
require.NoError(t, Init())
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Accept-Language", "es,en;q=0.5")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := Middleware()(func(c echo.Context) error {
// Verify localizer is set
localizer := GetLocalizer(c)
assert.NotNil(t, localizer)
// Verify locale is set
locale := GetLocale(c)
assert.Equal(t, "es", locale)
return nil
})
err := handler(c)
assert.NoError(t, err)
}
func TestGetLocalizer_NoContextValue_ReturnsDefault(t *testing.T) {
require.NoError(t, Init())
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
localizer := GetLocalizer(c)
assert.NotNil(t, localizer)
}
func TestGetLocale_NoContextValue_ReturnsDefault(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
locale := GetLocale(c)
assert.Equal(t, "en", locale)
}
func TestLocalizedMessage(t *testing.T) {
require.NoError(t, Init())
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Accept-Language", "en")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Set up localizer through middleware
handler := Middleware()(func(c echo.Context) error {
msg := LocalizedMessage(c, "error.task_not_found")
assert.Equal(t, "Task not found", msg)
return nil
})
err := handler(c)
assert.NoError(t, err)
}
func TestLocalizedMessageWithData(t *testing.T) {
require.NoError(t, Init())
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Accept-Language", "en")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := Middleware()(func(c echo.Context) error {
msg := LocalizedMessageWithData(c, "message.tasks_report_sent", map[string]interface{}{
"Email": "user@example.com",
})
assert.Contains(t, msg, "user@example.com")
return nil
})
err := handler(c)
assert.NoError(t, err)
}
func TestSupportedLanguages(t *testing.T) {
assert.Contains(t, SupportedLanguages, "en")
assert.Contains(t, SupportedLanguages, "es")
assert.Contains(t, SupportedLanguages, "fr")
assert.Equal(t, "en", DefaultLanguage)
}

View File

@@ -17,10 +17,10 @@ func TestIntegration_ContractorSharingFlow(t *testing.T) {
// ========== Setup Users ==========
// Create user A
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
// Create user B
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
// ========== User A creates residence C ==========
residenceBody := map[string]interface{}{
@@ -180,8 +180,8 @@ func TestIntegration_ContractorAccessWithoutResidenceShare(t *testing.T) {
app := setupContractorTest(t)
// Create two users
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
// User A creates a residence
residenceBody := map[string]interface{}{
@@ -228,9 +228,9 @@ func TestIntegration_ContractorUpdateAndDeleteAccess(t *testing.T) {
app := setupContractorTest(t)
// Create users
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
userCToken := app.registerAndLogin(t, "userC", "userC@test.com", "password123")
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
userCToken := app.registerAndLogin(t, "userC", "userC@test.com", "Password123")
// User A creates residence and shares with User B (not User C)
residenceBody := map[string]interface{}{"name": "Shared Residence"}

View File

@@ -379,7 +379,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
registerBody := map[string]string{
"username": "testuser",
"email": "test@example.com",
"password": "password123",
"password": "Password123",
}
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
assert.Equal(t, http.StatusCreated, w.Code)
@@ -388,7 +388,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
registerBody2 := map[string]string{
"username": "testuser",
"email": "different@example.com",
"password": "password123",
"password": "Password123",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody2, "")
assert.Equal(t, http.StatusConflict, w.Code)
@@ -397,7 +397,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
registerBody3 := map[string]string{
"username": "differentuser",
"email": "test@example.com",
"password": "password123",
"password": "Password123",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody3, "")
assert.Equal(t, http.StatusConflict, w.Code)
@@ -407,7 +407,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
func TestIntegration_ResidenceFlow(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
// 1. Create a residence
createBody := map[string]interface{}{
@@ -475,8 +475,8 @@ func TestIntegration_ResidenceSharingFlow(t *testing.T) {
app := setupIntegrationTest(t)
// Create owner and another user
ownerToken := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
userToken := app.registerAndLogin(t, "shareduser", "shared@test.com", "password123")
ownerToken := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
userToken := app.registerAndLogin(t, "shareduser", "shared@test.com", "Password123")
// Create residence as owner
createBody := map[string]interface{}{
@@ -531,7 +531,7 @@ func TestIntegration_ResidenceSharingFlow(t *testing.T) {
func TestIntegration_TaskFlow(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
// Create residence first
residenceBody := map[string]interface{}{"name": "Task House"}
@@ -633,7 +633,7 @@ func TestIntegration_TaskFlow(t *testing.T) {
func TestIntegration_TasksByResidenceKanban(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
// Use explicit timezone to test full timezone-aware path
testTimezone := "America/Los_Angeles"
@@ -682,7 +682,7 @@ func TestIntegration_TasksByResidenceKanban(t *testing.T) {
func TestIntegration_LookupEndpoints(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "user", "user@test.com", "password123")
token := app.registerAndLogin(t, "user", "user@test.com", "Password123")
tests := []struct {
name string
@@ -721,8 +721,8 @@ func TestIntegration_CrossUserAccessDenied(t *testing.T) {
app := setupIntegrationTest(t)
// Create two users with their own residences
user1Token := app.registerAndLogin(t, "user1", "user1@test.com", "password123")
user2Token := app.registerAndLogin(t, "user2", "user2@test.com", "password123")
user1Token := app.registerAndLogin(t, "user1", "user1@test.com", "Password123")
user2Token := app.registerAndLogin(t, "user2", "user2@test.com", "Password123")
// User1 creates a residence
residenceBody := map[string]interface{}{"name": "User1's House"}
@@ -777,7 +777,7 @@ func TestIntegration_CrossUserAccessDenied(t *testing.T) {
func TestIntegration_ResponseStructure(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "user", "user@test.com", "password123")
token := app.registerAndLogin(t, "user", "user@test.com", "Password123")
// Create residence
residenceBody := map[string]interface{}{
@@ -1704,7 +1704,7 @@ func setupContractorTest(t *testing.T) *TestApp {
// - Verify task moves between kanban columns appropriately
func TestIntegration_RecurringTaskLifecycle(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "recurring_user", "recurring@test.com", "password123")
token := app.registerAndLogin(t, "recurring_user", "recurring@test.com", "Password123")
// Create residence
residenceBody := map[string]interface{}{"name": "Recurring Task House"}
@@ -1904,9 +1904,9 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
t.Log("Phase 1: Create 3 users")
tokenA := app.registerAndLogin(t, "user_a", "usera@test.com", "password123")
tokenB := app.registerAndLogin(t, "user_b", "userb@test.com", "password123")
tokenC := app.registerAndLogin(t, "user_c", "userc@test.com", "password123")
tokenA := app.registerAndLogin(t, "user_a", "usera@test.com", "Password123")
tokenB := app.registerAndLogin(t, "user_b", "userb@test.com", "Password123")
tokenC := app.registerAndLogin(t, "user_c", "userc@test.com", "Password123")
t.Log("✓ Created users A, B, and C")
@@ -2098,7 +2098,7 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
// - Verify kanban column changes with each transition
func TestIntegration_TaskStateTransitions(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "state_user", "state@test.com", "password123")
token := app.registerAndLogin(t, "state_user", "state@test.com", "Password123")
// Create residence
residenceBody := map[string]interface{}{"name": "State Transition House"}
@@ -2274,7 +2274,7 @@ func TestIntegration_TaskStateTransitions(t *testing.T) {
// we're testing the full timezone-aware path, not just UTC defaults.
func TestIntegration_DateBoundaryEdgeCases(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "boundary_user", "boundary@test.com", "password123")
token := app.registerAndLogin(t, "boundary_user", "boundary@test.com", "Password123")
// Create residence
residenceBody := map[string]interface{}{"name": "Boundary Test House"}
@@ -2435,7 +2435,7 @@ func TestIntegration_DateBoundaryEdgeCases(t *testing.T) {
// - One where it's already "tomorrow" → task is overdue (due date was "yesterday")
func TestIntegration_TimezoneDivergence(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "tz_user", "tz@test.com", "password123")
token := app.registerAndLogin(t, "tz_user", "tz@test.com", "Password123")
// Create residence
residenceBody := map[string]interface{}{"name": "Timezone Test House"}
@@ -2584,7 +2584,7 @@ func findTaskColumn(kanbanResp map[string]interface{}, taskID uint) string {
// - Verify cascading effects
func TestIntegration_CascadeOperations(t *testing.T) {
app := setupIntegrationTest(t)
token := app.registerAndLogin(t, "cascade_user", "cascade@test.com", "password123")
token := app.registerAndLogin(t, "cascade_user", "cascade@test.com", "Password123")
t.Log("Phase 1: Create residence")
@@ -2721,8 +2721,8 @@ func TestIntegration_MultiUserOperations(t *testing.T) {
t.Log("Phase 1: Setup users and shared residence")
tokenA := app.registerAndLogin(t, "multiuser_a", "multiusera@test.com", "password123")
tokenB := app.registerAndLogin(t, "multiuser_b", "multiuserb@test.com", "password123")
tokenA := app.registerAndLogin(t, "multiuser_a", "multiusera@test.com", "Password123")
tokenB := app.registerAndLogin(t, "multiuser_b", "multiuserb@test.com", "Password123")
// User A creates residence
residenceBody := map[string]interface{}{"name": "Multi-User Test House"}

View File

@@ -265,8 +265,8 @@ func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
adminUserHandler := adminhandlers.NewAdminUserHandler(db)
// Create a couple of test users to have data to sort
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "password123")
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "password123")
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "Password123")
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "Password123")
// Set up a minimal Echo instance with the admin handler
e := echo.New()
@@ -322,7 +322,7 @@ func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
// garbage receipt data does NOT upgrade the user to Pro tier.
func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "password123")
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "Password123")
// Create initial subscription (free tier)
sub := &models.UserSubscription{UserID: userID, Tier: models.TierFree}
@@ -352,7 +352,7 @@ func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
// updates both the completion record and the task's NextDueDate together (P1-5/P1-6).
func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
app := setupSecurityTest(t)
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "password123")
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "Password123")
// Create a residence
residenceBody := map[string]interface{}{"name": "Atomic Test House"}
@@ -423,7 +423,7 @@ func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
// on a recurring task recalculates NextDueDate back to the correct value (P1-7).
func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
app := setupSecurityTest(t)
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "password123")
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "Password123")
// Create a residence
residenceBody := map[string]interface{}{"name": "Recurring Test House"}
@@ -510,7 +510,7 @@ func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
// configured property limit.
func TestE2E_TierLimits_Enforced(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "password123")
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "Password123")
// Enable global limitations
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
@@ -602,7 +602,7 @@ func TestE2E_AuthAssertion_NoPanics(t *testing.T) {
// caps the limit parameter to 200 even if the client requests more.
func TestE2E_NotificationLimit_Capped(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "password123")
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "Password123")
// Create 210 notifications directly in the database
for i := 0; i < 210; i++ {

View File

@@ -164,7 +164,7 @@ func TestIntegration_IsFreeBypassesLimitations(t *testing.T) {
app := setupSubscriptionTest(t)
// Register and login a user
token, userID := app.registerAndLogin(t, "freeuser", "free@test.com", "password123")
token, userID := app.registerAndLogin(t, "freeuser", "free@test.com", "Password123")
// Enable global limitations - first delete any existing, then create with enabled
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
@@ -215,7 +215,7 @@ func TestIntegration_IsFreeBypassesCheckLimit(t *testing.T) {
app := setupSubscriptionTest(t)
// Register and login a user
_, userID := app.registerAndLogin(t, "limituser", "limit@test.com", "password123")
_, userID := app.registerAndLogin(t, "limituser", "limit@test.com", "Password123")
// Enable global limitations
settings := &models.SubscriptionSettings{EnableLimitations: true}
@@ -282,7 +282,7 @@ func TestIntegration_IsFreeIndependentOfTier(t *testing.T) {
app := setupSubscriptionTest(t)
// Register and login a user
token, userID := app.registerAndLogin(t, "tieruser", "tier@test.com", "password123")
token, userID := app.registerAndLogin(t, "tieruser", "tier@test.com", "Password123")
// Enable global limitations
settings := &models.SubscriptionSettings{EnableLimitations: true}
@@ -340,7 +340,7 @@ func TestIntegration_IsFreeWhenGlobalLimitationsDisabled(t *testing.T) {
app := setupSubscriptionTest(t)
// Register and login a user
token, userID := app.registerAndLogin(t, "globaluser", "global@test.com", "password123")
token, userID := app.registerAndLogin(t, "globaluser", "global@test.com", "Password123")
// Disable global limitations
settings := &models.SubscriptionSettings{EnableLimitations: false}

View File

@@ -0,0 +1,163 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
)
func TestAdminAuth_NoHeader_Returns401(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
mw := AdminAuthMiddleware(cfg, nil)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Contains(t, rec.Body.String(), "Authorization required")
}
func TestAdminAuth_InvalidToken_Returns401(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
mw := AdminAuthMiddleware(cfg, nil)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
req.Header.Set("Authorization", "Bearer invalid-jwt-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid token")
}
func TestAdminAuth_TokenSchemeOnly_Returns401(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
mw := AdminAuthMiddleware(cfg, nil)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
// "Token" scheme is not supported for admin auth, only "Bearer"
req.Header.Set("Authorization", "Token some-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestRequireSuperAdmin_NoAdmin_Returns401(t *testing.T) {
mw := RequireSuperAdmin()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// No admin in context
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestRequireSuperAdmin_WrongType_Returns401(t *testing.T) {
mw := RequireSuperAdmin()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Wrong type in context
c.Set(AdminUserKey, "not-an-admin")
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestRequireSuperAdmin_NonSuperAdmin_Returns403(t *testing.T) {
mw := RequireSuperAdmin()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Regular admin (not super admin)
admin := &models.AdminUser{
Email: "admin@test.com",
IsActive: true,
Role: models.AdminRoleAdmin,
}
c.Set(AdminUserKey, admin)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, rec.Code)
assert.Contains(t, rec.Body.String(), "Super admin privileges required")
}
func TestRequireSuperAdmin_SuperAdmin_Passes(t *testing.T) {
mw := RequireSuperAdmin()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
admin := &models.AdminUser{
Email: "superadmin@test.com",
IsActive: true,
Role: models.AdminRoleSuperAdmin,
}
c.Set(AdminUserKey, admin)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -41,7 +41,7 @@ func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays
Email: username + "@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("password123"))
require.NoError(t, user.SetPassword("Password123"))
require.NoError(t, db.Create(user).Error)
token := &models.AuthToken{

View File

@@ -0,0 +1,337 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
)
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Bearer "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
user := GetAuthUser(c)
require.NotNil(t, user)
assert.Equal(t, "bearer_user", user.Username)
}
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Basic "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "JustATokenWithNoScheme")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token ")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
db := setupTestDB(t)
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
// Deactivate the user
require.NoError(t, db.Model(user).Update("is_active", false).Error)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
// No Authorization header
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "opt_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "opt_user", rec.Body.String())
}
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token nonexistent-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
db := setupTestDB(t)
cfg := &config.Config{
Security: config.SecurityConfig{
TokenExpiryDays: 30,
},
}
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
assert.NotNil(t, m)
assert.Equal(t, 30, m.tokenExpiryDays)
// Token at 29 days should be valid
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
db := setupTestDB(t)
cfg := &config.Config{
Security: config.SecurityConfig{
TokenExpiryDays: 30,
},
}
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
// Token at 31 days should be expired with 30-day config
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddlewareWithConfig(db, nil, nil)
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
}
func TestGetAuthToken_ReturnsToken(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthTokenKey, "test-token-value")
assert.Equal(t, "test-token-value", GetAuthToken(c))
}
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// No token set
assert.Equal(t, "", GetAuthToken(c))
}
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthTokenKey, 12345) // Wrong type
assert.Equal(t, "", GetAuthToken(c))
}
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
// Legacy tokens without created time should not be expired
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
}
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil) // nil cache
err := m.InvalidateToken(nil, "some-token")
assert.NoError(t, err)
}

View File

@@ -0,0 +1,93 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
mw := HostCheck([]string{"api.example.com", "localhost:8000"})
e := echo.New()
handler := mw(okHandler)
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Host = "api.example.com"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
mw := HostCheck([]string{"api.example.com"})
e := echo.New()
handler := mw(okHandler)
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Host = "evil.example.com"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, rec.Code)
var response map[string]interface{}
err = json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "Forbidden", response["error"])
}
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
mw := HostCheck([]string{})
e := echo.New()
handler := mw(okHandler)
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Host = "any-host.example.com"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
mw := HostCheck([]string{"localhost:8000"})
e := echo.New()
handler := mw(okHandler)
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Host = "localhost:8000"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
// Only "localhost:8000" allowed, not plain "localhost"
mw := HostCheck([]string{"localhost:8000"})
e := echo.New()
handler := mw(okHandler)
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Host = "localhost"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, rec.Code)
}

View File

@@ -0,0 +1,103 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
)
func TestStructuredLogger_Passes_Request(t *testing.T) {
mw := StructuredLogger()
e := echo.New()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "ok", rec.Body.String())
}
func TestStructuredLogger_WithUser(t *testing.T) {
mw := StructuredLogger()
e := echo.New()
user := &models.User{Username: "loguser"}
user.ID = 42
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthUserKey, user)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestStructuredLogger_WithRequestID(t *testing.T) {
mw := StructuredLogger()
e := echo.New()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(ContextKeyRequestID, "test-request-id-123")
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestStructuredLogger_ErrorStatus(t *testing.T) {
mw := StructuredLogger()
e := echo.New()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusInternalServerError, "error")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestStructuredLogger_ClientError(t *testing.T) {
mw := StructuredLogger()
e := echo.New()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusBadRequest, "bad request")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -0,0 +1,222 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTimezoneMiddleware_IANATimezone(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
handler := mw(func(c echo.Context) error {
loc := GetUserTimezone(c)
return c.String(http.StatusOK, loc.String())
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set(TimezoneHeader, "America/New_York")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, "America/New_York", rec.Body.String())
}
func TestTimezoneMiddleware_UTCOffset(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
handler := mw(func(c echo.Context) error {
loc := GetUserTimezone(c)
// Just verify it's not UTC (if offset is non-zero)
return c.String(http.StatusOK, loc.String())
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set(TimezoneHeader, "-05:00")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, "-05:00", rec.Body.String())
}
func TestTimezoneMiddleware_NoHeader_DefaultsToUTC(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
handler := mw(func(c echo.Context) error {
loc := GetUserTimezone(c)
return c.String(http.StatusOK, loc.String())
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
// No X-Timezone header
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, "UTC", rec.Body.String())
}
func TestTimezoneMiddleware_InvalidTimezone_DefaultsToUTC(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
handler := mw(func(c echo.Context) error {
loc := GetUserTimezone(c)
return c.String(http.StatusOK, loc.String())
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set(TimezoneHeader, "Invalid/Timezone")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, "UTC", rec.Body.String())
}
func TestTimezoneMiddleware_SetsUserNow(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
var capturedNow time.Time
handler := mw(func(c echo.Context) error {
capturedNow = GetUserNow(c)
return c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set(TimezoneHeader, "America/Chicago")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
// The "now" should be the start of the day in the user's timezone
assert.Equal(t, 0, capturedNow.Hour())
assert.Equal(t, 0, capturedNow.Minute())
assert.Equal(t, 0, capturedNow.Second())
}
func TestTimezoneMiddleware_SetsTimezoneName(t *testing.T) {
mw := TimezoneMiddleware()
e := echo.New()
handler := mw(func(c echo.Context) error {
name := GetTimezoneName(c)
return c.String(http.StatusOK, name)
})
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set(TimezoneHeader, "Europe/London")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.NoError(t, err)
assert.Equal(t, "Europe/London", rec.Body.String())
}
func TestGetUserTimezone_NotSet_ReturnsUTC(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// No timezone set in context
loc := GetUserTimezone(c)
assert.Equal(t, time.UTC, loc)
}
func TestGetUserTimezone_WrongType_ReturnsUTC(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(TimezoneKey, "not-a-location")
loc := GetUserTimezone(c)
assert.Equal(t, time.UTC, loc)
}
func TestGetUserNow_NotSet_ReturnsUTCNow(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
before := time.Now().UTC()
now := GetUserNow(c)
after := time.Now().UTC()
assert.True(t, !now.Before(before.Add(-time.Second)), "now should be roughly after before")
assert.True(t, !now.After(after.Add(time.Second)), "now should be roughly before after")
}
func TestGetUserNow_WrongType_ReturnsUTCNow(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(UserNowKey, "not-a-time")
now := GetUserNow(c)
assert.NotNil(t, now)
}
func TestIsTimezoneChanged_NoChange_ReturnsFalse(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(TimezoneChangedKey, false)
assert.False(t, IsTimezoneChanged(c))
}
func TestIsTimezoneChanged_Changed_ReturnsTrue(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(TimezoneChangedKey, true)
assert.True(t, IsTimezoneChanged(c))
}
func TestIsTimezoneChanged_NotSet_ReturnsFalse(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Not set at all
assert.False(t, IsTimezoneChanged(c))
}
func TestParseTimezone_UTCOffsetWithoutColon(t *testing.T) {
loc := parseTimezone("-0800")
assert.NotEqual(t, time.UTC, loc)
assert.Equal(t, "-0800", loc.String())
}
func TestParseTimezone_PositiveOffset(t *testing.T) {
loc := parseTimezone("+05:30")
assert.NotEqual(t, time.UTC, loc)
assert.Equal(t, "+05:30", loc.String())
}
func TestParseTimezone_UTC(t *testing.T) {
loc := parseTimezone("UTC")
assert.Equal(t, time.UTC, loc)
}

View File

@@ -0,0 +1,186 @@
package middleware
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
)
func TestUserCache_SetAndGet(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user := &models.User{Username: "testuser", Email: "test@test.com"}
user.ID = 1
cache.Set(user)
cached := cache.Get(1)
require.NotNil(t, cached)
assert.Equal(t, "testuser", cached.Username)
assert.Equal(t, "test@test.com", cached.Email)
}
func TestUserCache_GetNonExistent_ReturnsNil(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
cached := cache.Get(999)
assert.Nil(t, cached)
}
func TestUserCache_Expired_ReturnsNil(t *testing.T) {
// Very short TTL
cache := NewUserCache(1 * time.Millisecond)
user := &models.User{Username: "expiring_user"}
user.ID = 1
cache.Set(user)
// Wait for expiry
time.Sleep(5 * time.Millisecond)
cached := cache.Get(1)
assert.Nil(t, cached, "expired entry should return nil")
}
func TestUserCache_Invalidate(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user := &models.User{Username: "to_invalidate"}
user.ID = 1
cache.Set(user)
// Verify it's cached
require.NotNil(t, cache.Get(1))
// Invalidate
cache.Invalidate(1)
// Should be gone
assert.Nil(t, cache.Get(1))
}
func TestUserCache_ReturnsCopy_NotOriginal(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user := &models.User{Username: "original"}
user.ID = 1
cache.Set(user)
// Modify the returned copy
cached := cache.Get(1)
require.NotNil(t, cached)
cached.Username = "modified"
// Original cache entry should be unaffected
cached2 := cache.Get(1)
require.NotNil(t, cached2)
assert.Equal(t, "original", cached2.Username, "cache should return a copy, not the original")
}
func TestUserCache_SetCopiesInput(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user := &models.User{Username: "original"}
user.ID = 1
cache.Set(user)
// Modify the input after setting
user.Username = "modified_after_set"
// Cache should still have the original value
cached := cache.Get(1)
require.NotNil(t, cached)
assert.Equal(t, "original", cached.Username, "cache should store a copy of the input")
}
func TestUserCache_MultipleUsers(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user1 := &models.User{Username: "user1"}
user1.ID = 1
user2 := &models.User{Username: "user2"}
user2.ID = 2
cache.Set(user1)
cache.Set(user2)
cached1 := cache.Get(1)
cached2 := cache.Get(2)
require.NotNil(t, cached1)
require.NotNil(t, cached2)
assert.Equal(t, "user1", cached1.Username)
assert.Equal(t, "user2", cached2.Username)
}
func TestUserCache_OverwriteEntry(t *testing.T) {
cache := NewUserCache(1 * time.Minute)
user := &models.User{Username: "original"}
user.ID = 1
cache.Set(user)
// Overwrite with new data
updated := &models.User{Username: "updated"}
updated.ID = 1
cache.Set(updated)
cached := cache.Get(1)
require.NotNil(t, cached)
assert.Equal(t, "updated", cached.Username)
}
func TestTimezoneCache_GetAndCompare_NewEntry(t *testing.T) {
tc := NewTimezoneCache()
// First call should return false (not cached yet)
unchanged := tc.GetAndCompare(1, "America/New_York")
assert.False(t, unchanged, "first call should indicate a change")
}
func TestTimezoneCache_GetAndCompare_SameValue(t *testing.T) {
tc := NewTimezoneCache()
// First call sets the value
tc.GetAndCompare(1, "America/New_York")
// Second call with same value should return true (unchanged)
unchanged := tc.GetAndCompare(1, "America/New_York")
assert.True(t, unchanged, "same value should indicate no change")
}
func TestTimezoneCache_GetAndCompare_DifferentValue(t *testing.T) {
tc := NewTimezoneCache()
// Set initial value
tc.GetAndCompare(1, "America/New_York")
// Update to different value
unchanged := tc.GetAndCompare(1, "America/Chicago")
assert.False(t, unchanged, "different value should indicate a change")
// Now the new value is cached
unchanged = tc.GetAndCompare(1, "America/Chicago")
assert.True(t, unchanged, "same value should indicate no change")
}
func TestTimezoneCache_GetAndCompare_DifferentUsers(t *testing.T) {
tc := NewTimezoneCache()
tc.GetAndCompare(1, "America/New_York")
tc.GetAndCompare(2, "Europe/London")
assert.True(t, tc.GetAndCompare(1, "America/New_York"))
assert.True(t, tc.GetAndCompare(2, "Europe/London"))
assert.False(t, tc.GetAndCompare(1, "Europe/London"))
}

View File

@@ -0,0 +1,626 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// setupModelsTestDB creates a minimal in-memory SQLite for model-level tests
// that require database interaction (e.g., BeforeCreate hooks).
func setupModelsTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&User{}, &AuthToken{}, &UserProfile{})
require.NoError(t, err)
return db
}
// === Residence model tests ===
func TestResidence_GetAllUsers(t *testing.T) {
owner := User{Username: "owner"}
owner.ID = 1
member1 := User{Username: "member1"}
member1.ID = 2
member2 := User{Username: "member2"}
member2.ID = 3
residence := &Residence{
OwnerID: owner.ID,
Owner: owner,
Users: []User{member1, member2},
}
allUsers := residence.GetAllUsers()
assert.Len(t, allUsers, 3)
assert.Equal(t, "owner", allUsers[0].Username)
}
func TestResidence_HasAccess(t *testing.T) {
owner := User{Username: "owner"}
owner.ID = 1
member := User{Username: "member"}
member.ID = 2
residence := &Residence{
OwnerID: owner.ID,
Owner: owner,
Users: []User{member},
}
tests := []struct {
name string
userID uint
expected bool
}{
{"owner has access", 1, true},
{"member has access", 2, true},
{"stranger no access", 99, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, residence.HasAccess(tt.userID))
})
}
}
func TestResidence_IsPrimaryOwner(t *testing.T) {
residence := &Residence{OwnerID: 1}
assert.True(t, residence.IsPrimaryOwner(1))
assert.False(t, residence.IsPrimaryOwner(2))
}
// === Document model tests ===
func TestDocument_TableName_DocumentImage(t *testing.T) {
di := DocumentImage{}
assert.Equal(t, "task_documentimage", di.TableName())
}
func TestDocument_IsWarrantyExpiringSoon(t *testing.T) {
future30 := time.Now().UTC().AddDate(0, 0, 15)
future90 := time.Now().UTC().AddDate(0, 0, 60)
past := time.Now().UTC().AddDate(0, 0, -5)
tests := []struct {
name string
doc Document
days int
expected bool
}{
{
name: "warranty expiring within threshold",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future30},
days: 30,
expected: true,
},
{
name: "warranty not expiring within threshold",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future90},
days: 30,
expected: false,
},
{
name: "warranty already expired",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &past},
days: 30,
expected: false,
},
{
name: "non-warranty document",
doc: Document{DocumentType: DocumentTypeGeneral, ExpiryDate: &future30},
days: 30,
expected: false,
},
{
name: "warranty with nil expiry",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: nil},
days: 30,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.doc.IsWarrantyExpiringSoon(tt.days))
})
}
}
func TestDocument_IsWarrantyExpired(t *testing.T) {
past := time.Now().UTC().AddDate(0, 0, -5)
future := time.Now().UTC().AddDate(0, 0, 30)
tests := []struct {
name string
doc Document
expected bool
}{
{
name: "expired warranty",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &past},
expected: true,
},
{
name: "active warranty",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future},
expected: false,
},
{
name: "non-warranty",
doc: Document{DocumentType: DocumentTypeGeneral, ExpiryDate: &past},
expected: false,
},
{
name: "warranty nil expiry",
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.doc.IsWarrantyExpired())
})
}
}
func TestDocumentType_Constants(t *testing.T) {
// Verify document type constants have expected values
assert.Equal(t, DocumentType("general"), DocumentTypeGeneral)
assert.Equal(t, DocumentType("warranty"), DocumentTypeWarranty)
assert.Equal(t, DocumentType("receipt"), DocumentTypeReceipt)
assert.Equal(t, DocumentType("contract"), DocumentTypeContract)
assert.Equal(t, DocumentType("insurance"), DocumentTypeInsurance)
assert.Equal(t, DocumentType("manual"), DocumentTypeManual)
}
// === Notification model tests ===
func TestNotification_TableName(t *testing.T) {
n := Notification{}
assert.Equal(t, "notifications_notification", n.TableName())
}
func TestNotificationPreference_TableName(t *testing.T) {
np := NotificationPreference{}
assert.Equal(t, "notifications_notificationpreference", np.TableName())
}
func TestAPNSDevice_TableName(t *testing.T) {
d := APNSDevice{}
assert.Equal(t, "push_notifications_apnsdevice", d.TableName())
}
func TestGCMDevice_TableName(t *testing.T) {
d := GCMDevice{}
assert.Equal(t, "push_notifications_gcmdevice", d.TableName())
}
func TestNotification_MarkAsRead(t *testing.T) {
n := &Notification{Read: false}
n.MarkAsRead()
assert.True(t, n.Read)
assert.NotNil(t, n.ReadAt)
}
func TestNotification_MarkAsSent(t *testing.T) {
n := &Notification{Sent: false}
n.MarkAsSent()
assert.True(t, n.Sent)
assert.NotNil(t, n.SentAt)
}
func TestNotificationType_Constants(t *testing.T) {
assert.Equal(t, NotificationType("task_due_soon"), NotificationTaskDueSoon)
assert.Equal(t, NotificationType("task_overdue"), NotificationTaskOverdue)
assert.Equal(t, NotificationType("task_completed"), NotificationTaskCompleted)
assert.Equal(t, NotificationType("task_assigned"), NotificationTaskAssigned)
assert.Equal(t, NotificationType("residence_shared"), NotificationResidenceShared)
assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring)
}
// === AuthToken model tests ===
func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "tokenuser",
Email: "token@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token := &AuthToken{UserID: user.ID}
err = db.Create(token).Error
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Len(t, token.Key, 40) // 20 bytes = 40 hex chars
assert.False(t, token.Created.IsZero())
}
func TestAuthToken_BeforeCreate_PreservesExistingKey(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "tokenuser",
Email: "token@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
existingKey := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
token := &AuthToken{
Key: existingKey,
UserID: user.ID,
}
err = db.Create(token).Error
require.NoError(t, err)
assert.Equal(t, existingKey, token.Key)
}
func TestGetOrCreateToken_CreatesNew(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "newtoken",
Email: "newtoken@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Equal(t, user.ID, token.UserID)
}
func TestGetOrCreateToken_ReturnsExisting(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "existingtoken",
Email: "existingtoken@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token1, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
token2, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
assert.Equal(t, token1.Key, token2.Key)
}
// === User model additional tests ===
func TestUser_SetPassword_And_CheckPassword_Integration(t *testing.T) {
user := &User{}
err := user.SetPassword("Password123")
require.NoError(t, err)
assert.True(t, user.CheckPassword("Password123"))
assert.False(t, user.CheckPassword("WrongPassword"))
assert.False(t, user.CheckPassword(""))
assert.False(t, user.CheckPassword("password123")) // case sensitive
}
// === Task model additional tests ===
func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) {
yesterday := time.Now().UTC().AddDate(0, 0, -2)
task := &Task{
DueDate: &yesterday,
IsCancelled: true,
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsOverdue_ArchivedNotOverdue(t *testing.T) {
yesterday := time.Now().UTC().AddDate(0, 0, -2)
task := &Task{
DueDate: &yesterday,
IsArchived: true,
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsOverdue_NoDueDateNotOverdue(t *testing.T) {
task := &Task{
DueDate: nil,
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsOverdue_CompletedNotOverdue(t *testing.T) {
yesterday := time.Now().UTC().AddDate(0, 0, -2)
task := &Task{
DueDate: &yesterday,
NextDueDate: nil,
Completions: []TaskCompletion{{CompletedAt: time.Now().UTC()}},
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsOverdue_CompletionCountNotOverdue(t *testing.T) {
yesterday := time.Now().UTC().AddDate(0, 0, -2)
task := &Task{
DueDate: &yesterday,
NextDueDate: nil,
CompletionCount: 1,
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsOverdue_UsesNextDueDate(t *testing.T) {
// DueDate is overdue, but NextDueDate is in the future
pastDue := time.Now().UTC().AddDate(0, 0, -10)
futureDue := time.Now().UTC().AddDate(0, 0, 10)
task := &Task{
DueDate: &pastDue,
NextDueDate: &futureDue,
}
assert.False(t, task.IsOverdue())
}
func TestTask_IsDueSoon_CancelledNotDueSoon(t *testing.T) {
futureDue := time.Now().UTC().AddDate(0, 0, 5)
task := &Task{
DueDate: &futureDue,
IsCancelled: true,
}
assert.False(t, task.IsDueSoon(30))
}
func TestTask_IsDueSoon_NoDueDateNotDueSoon(t *testing.T) {
task := &Task{
DueDate: nil,
}
assert.False(t, task.IsDueSoon(30))
}
func TestTask_IsDueSoon_WithinThreshold(t *testing.T) {
futureDue := time.Now().UTC().AddDate(0, 0, 5)
task := &Task{
DueDate: &futureDue,
}
assert.True(t, task.IsDueSoon(30))
assert.True(t, task.IsDueSoon(10))
assert.False(t, task.IsDueSoon(3))
}
func TestTask_IsDueSoon_CompletedNotDueSoon(t *testing.T) {
futureDue := time.Now().UTC().AddDate(0, 0, 5)
task := &Task{
DueDate: &futureDue,
NextDueDate: nil,
Completions: []TaskCompletion{{CompletedAt: time.Now().UTC()}},
}
assert.False(t, task.IsDueSoon(30))
}
func TestTaskCompletionImage_TableName(t *testing.T) {
tci := TaskCompletionImage{}
assert.Equal(t, "task_taskcompletionimage", tci.TableName())
}
// === Subscription model additional tests ===
func TestSubscription_TableNames(t *testing.T) {
assert.Equal(t, "subscription_subscriptionsettings", SubscriptionSettings{}.TableName())
assert.Equal(t, "subscription_usersubscription", UserSubscription{}.TableName())
assert.Equal(t, "subscription_upgradetrigger", UpgradeTrigger{}.TableName())
assert.Equal(t, "subscription_featurebenefit", FeatureBenefit{}.TableName())
assert.Equal(t, "subscription_promotion", Promotion{}.TableName())
assert.Equal(t, "subscription_tierlimits", TierLimits{}.TableName())
}
func TestSubscription_IsActive(t *testing.T) {
future := time.Now().UTC().Add(24 * time.Hour)
past := time.Now().UTC().Add(-24 * time.Hour)
tests := []struct {
name string
sub *UserSubscription
expected bool
}{
{
name: "pro with future expiry is active",
sub: &UserSubscription{Tier: TierPro, ExpiresAt: &future},
expected: true,
},
{
name: "pro with nil expiry is active",
sub: &UserSubscription{Tier: TierPro, ExpiresAt: nil},
expected: true,
},
{
name: "pro with past expiry is not active",
sub: &UserSubscription{Tier: TierPro, ExpiresAt: &past},
expected: false,
},
{
name: "free with active trial is active",
sub: &UserSubscription{Tier: TierFree, TrialEnd: &future},
expected: true,
},
{
name: "free without trial is not active",
sub: &UserSubscription{Tier: TierFree},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.sub.IsActive())
})
}
}
func TestSubscription_SubscriptionSource(t *testing.T) {
sub := &UserSubscription{Platform: "ios"}
assert.Equal(t, "ios", sub.SubscriptionSource())
}
func TestPromotion_IsCurrentlyActive(t *testing.T) {
now := time.Now().UTC()
tests := []struct {
name string
promo Promotion
expected bool
}{
{
name: "active promotion within dates",
promo: Promotion{
IsActive: true,
StartDate: now.Add(-1 * time.Hour),
EndDate: now.Add(1 * time.Hour),
},
expected: true,
},
{
name: "inactive promotion",
promo: Promotion{
IsActive: false,
StartDate: now.Add(-1 * time.Hour),
EndDate: now.Add(1 * time.Hour),
},
expected: false,
},
{
name: "promotion not yet started",
promo: Promotion{
IsActive: true,
StartDate: now.Add(1 * time.Hour),
EndDate: now.Add(2 * time.Hour),
},
expected: false,
},
{
name: "promotion already ended",
promo: Promotion{
IsActive: true,
StartDate: now.Add(-2 * time.Hour),
EndDate: now.Add(-1 * time.Hour),
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.promo.IsCurrentlyActive())
})
}
}
func TestGetDefaultFreeLimits(t *testing.T) {
limits := GetDefaultFreeLimits()
assert.Equal(t, TierFree, limits.Tier)
require.NotNil(t, limits.PropertiesLimit)
require.NotNil(t, limits.TasksLimit)
require.NotNil(t, limits.ContractorsLimit)
require.NotNil(t, limits.DocumentsLimit)
assert.Equal(t, 1, *limits.PropertiesLimit)
assert.Equal(t, 10, *limits.TasksLimit)
assert.Equal(t, 0, *limits.ContractorsLimit)
assert.Equal(t, 0, *limits.DocumentsLimit)
}
func TestGetDefaultProLimits(t *testing.T) {
limits := GetDefaultProLimits()
assert.Equal(t, TierPro, limits.Tier)
assert.Nil(t, limits.PropertiesLimit)
assert.Nil(t, limits.TasksLimit)
assert.Nil(t, limits.ContractorsLimit)
assert.Nil(t, limits.DocumentsLimit)
}
// === ConfirmationCode additional tests ===
func TestConfirmationCode_TableName(t *testing.T) {
cc := ConfirmationCode{}
assert.Equal(t, "user_confirmationcode", cc.TableName())
}
// === PasswordResetCode additional tests ===
func TestPasswordResetCode_TableName(t *testing.T) {
prc := PasswordResetCode{}
assert.Equal(t, "user_passwordresetcode", prc.TableName())
}
// === Social Auth TableName tests ===
func TestAppleSocialAuth_TableName(t *testing.T) {
a := AppleSocialAuth{}
assert.Equal(t, "user_applesocialauth", a.TableName())
}
func TestGoogleSocialAuth_TableName(t *testing.T) {
g := GoogleSocialAuth{}
assert.Equal(t, "user_googlesocialauth", g.TableName())
}
// === BaseModel tests ===
func TestBaseModel_BeforeCreate(t *testing.T) {
b := &BaseModel{}
err := b.BeforeCreate(nil)
require.NoError(t, err)
assert.False(t, b.CreatedAt.IsZero())
assert.False(t, b.UpdatedAt.IsZero())
}
func TestBaseModel_BeforeCreate_PreservesExisting(t *testing.T) {
existingTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
b := &BaseModel{
CreatedAt: existingTime,
UpdatedAt: existingTime,
}
err := b.BeforeCreate(nil)
require.NoError(t, err)
assert.Equal(t, existingTime, b.CreatedAt)
assert.Equal(t, existingTime, b.UpdatedAt)
}
func TestBaseModel_BeforeUpdate(t *testing.T) {
oldTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
b := &BaseModel{
UpdatedAt: oldTime,
}
err := b.BeforeUpdate(nil)
require.NoError(t, err)
assert.True(t, b.UpdatedAt.After(oldTime))
}

View File

@@ -11,10 +11,10 @@ import (
func TestUser_SetPassword(t *testing.T) {
user := &User{}
err := user.SetPassword("testpassword123")
err := user.SetPassword("testPassword123")
require.NoError(t, err)
assert.NotEmpty(t, user.Password)
assert.NotEqual(t, "testpassword123", user.Password) // Should be hashed
assert.NotEqual(t, "testPassword123", user.Password) // Should be hashed
}
func TestUser_CheckPassword(t *testing.T) {

View File

@@ -0,0 +1,233 @@
package monitoring
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLogFilters_GetLimit(t *testing.T) {
tests := []struct {
name string
limit int
expected int
}{
{"default for zero", 0, 100},
{"default for negative", -5, 100},
{"capped at 1000", 2000, 1000},
{"exactly 1000", 1000, 1000},
{"normal value", 50, 50},
{"minimum valid", 1, 1},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := LogFilters{Limit: tc.limit}
assert.Equal(t, tc.expected, f.GetLimit())
})
}
}
func TestHTTPStatsCollector_Record_And_GetStats(t *testing.T) {
c := NewHTTPStatsCollector()
c.Record("GET /api/tasks/", 100*time.Millisecond, 200)
c.Record("GET /api/tasks/", 200*time.Millisecond, 200)
c.Record("POST /api/tasks/", 50*time.Millisecond, 201)
c.Record("GET /api/tasks/", 300*time.Millisecond, 500)
stats := c.GetStats()
assert.Equal(t, int64(4), stats.RequestsTotal)
assert.True(t, stats.RequestsPerMinute > 0)
// Check endpoint stats
taskStats, ok := stats.ByEndpoint["GET /api/tasks/"]
assert.True(t, ok)
assert.Equal(t, int64(3), taskStats.Count)
assert.True(t, taskStats.AvgLatencyMs > 0)
postStats, ok := stats.ByEndpoint["POST /api/tasks/"]
assert.True(t, ok)
assert.Equal(t, int64(1), postStats.Count)
// Check status codes
assert.Equal(t, int64(2), stats.ByStatusCode[200])
assert.Equal(t, int64(1), stats.ByStatusCode[201])
assert.Equal(t, int64(1), stats.ByStatusCode[500])
// Error rate should include 500 status
assert.True(t, stats.ErrorRate > 0)
}
func TestHTTPStatsCollector_Reset(t *testing.T) {
c := NewHTTPStatsCollector()
c.Record("GET /api/tasks/", 100*time.Millisecond, 200)
c.Reset()
stats := c.GetStats()
assert.Equal(t, int64(0), stats.RequestsTotal)
assert.Empty(t, stats.ByEndpoint)
assert.Empty(t, stats.ByStatusCode)
}
func TestHTTPStatsCollector_ErrorRate(t *testing.T) {
c := NewHTTPStatsCollector()
c.Record("GET /api/tasks/", 10*time.Millisecond, 200)
c.Record("GET /api/tasks/", 10*time.Millisecond, 400)
c.Record("GET /api/tasks/", 10*time.Millisecond, 500)
stats := c.GetStats()
ep := stats.ByEndpoint["GET /api/tasks/"]
// 2 out of 3 are errors (400 and 500)
assert.InDelta(t, 2.0/3.0, ep.ErrorRate, 0.001)
}
func TestHTTPStatsCollector_P95(t *testing.T) {
c := NewHTTPStatsCollector()
// Record 100 requests with increasing latencies
for i := 1; i <= 100; i++ {
c.Record("GET /api/test/", time.Duration(i)*time.Millisecond, 200)
}
stats := c.GetStats()
ep := stats.ByEndpoint["GET /api/test/"]
// P95 should be around 95ms
assert.True(t, ep.P95LatencyMs >= 90, "P95 should be >= 90ms, got %f", ep.P95LatencyMs)
}
func TestHTTPStatsCollector_EmptyStats(t *testing.T) {
c := NewHTTPStatsCollector()
stats := c.GetStats()
assert.Equal(t, int64(0), stats.RequestsTotal)
assert.Equal(t, float64(0), stats.AvgLatencyMs)
assert.Equal(t, float64(0), stats.ErrorRate)
assert.Equal(t, float64(0), stats.RequestsPerMinute)
}
func TestHTTPStatsCollector_EndpointOverflow(t *testing.T) {
c := NewHTTPStatsCollector()
// Fill up to maxEndpoints unique endpoints
for i := 0; i < maxEndpoints+10; i++ {
endpoint := "GET /api/test/" + string(rune('A'+i%26)) + string(rune('0'+i/26))
c.Record(endpoint, 10*time.Millisecond, 200)
}
stats := c.GetStats()
// Should have at most maxEndpoints + 1 (the OTHER bucket)
assert.LessOrEqual(t, len(stats.ByEndpoint), maxEndpoints+1)
}
func TestWSMessageConstants(t *testing.T) {
assert.Equal(t, "log", WSMessageTypeLog)
assert.Equal(t, "stats", WSMessageTypeStats)
}
func TestRedisLogWriter_Write_Disabled(t *testing.T) {
// Create a writer with a nil buffer -- won't actually push to Redis
// but we can test the enabled/disabled logic
w := &RedisLogWriter{
process: "api",
ch: make(chan LogEntry, writerChannelSize),
done: make(chan struct{}),
}
w.enabled.Store(false)
// Start drain loop (reads from channel)
go func() {
defer close(w.done)
for range w.ch {
}
}()
n, err := w.Write([]byte(`{"level":"info","message":"test"}`))
assert.NoError(t, err)
assert.Greater(t, n, 0)
// Channel should be empty since writer is disabled
assert.Equal(t, 0, len(w.ch))
close(w.ch)
<-w.done
}
func TestRedisLogWriter_Write_Enabled(t *testing.T) {
w := &RedisLogWriter{
process: "api",
ch: make(chan LogEntry, writerChannelSize),
done: make(chan struct{}),
}
w.enabled.Store(true)
go func() {
defer close(w.done)
for range w.ch {
}
}()
n, err := w.Write([]byte(`{"level":"info","message":"hello","caller":"main.go:10"}`))
assert.NoError(t, err)
assert.Greater(t, n, 0)
// Give the goroutine a moment, then close
close(w.ch)
<-w.done
}
func TestRedisLogWriter_Write_InvalidJSON(t *testing.T) {
w := &RedisLogWriter{
process: "api",
ch: make(chan LogEntry, writerChannelSize),
done: make(chan struct{}),
}
w.enabled.Store(true)
go func() {
defer close(w.done)
for range w.ch {
}
}()
// Non-JSON input should be silently skipped
n, err := w.Write([]byte("not json at all"))
assert.NoError(t, err)
assert.Greater(t, n, 0)
assert.Equal(t, 0, len(w.ch))
close(w.ch)
<-w.done
}
func TestRedisLogWriter_SetEnabled_IsEnabled(t *testing.T) {
w := &RedisLogWriter{
process: "api",
ch: make(chan LogEntry, 1),
done: make(chan struct{}),
}
w.enabled.Store(true)
assert.True(t, w.IsEnabled())
w.SetEnabled(false)
assert.False(t, w.IsEnabled())
w.SetEnabled(true)
assert.True(t, w.IsEnabled())
}
func TestCollector_Stop_MultipleCallsSafe(t *testing.T) {
c := &Collector{
process: "api",
startTime: time.Now(),
stopChan: make(chan struct{}),
}
// Should not panic on multiple calls
c.Stop()
c.Stop()
c.Stop()
}

View File

@@ -0,0 +1,359 @@
package push
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// === truncateToken tests ===
func TestTruncateToken_LongToken(t *testing.T) {
token := "abcdefghijklmnopqrstuvwxyz1234567890"
result := truncateToken(token)
assert.Equal(t, "abcdefgh...", result)
}
func TestTruncateToken_ShortToken(t *testing.T) {
token := "abc"
result := truncateToken(token)
assert.Equal(t, "abc", result)
}
func TestTruncateToken_ExactlyEightChars(t *testing.T) {
token := "12345678"
result := truncateToken(token)
assert.Equal(t, "12345678", result)
}
func TestTruncateToken_NineChars(t *testing.T) {
token := "123456789"
result := truncateToken(token)
assert.Equal(t, "12345678...", result)
}
func TestTruncateToken_Empty(t *testing.T) {
result := truncateToken("")
assert.Equal(t, "", result)
}
// === Client tests ===
func TestClient_SendToIOS_Disabled(t *testing.T) {
client := &Client{
enabled: false,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.NoError(t, err) // Returns nil when disabled
}
func TestClient_SendToIOS_NilAPNs(t *testing.T) {
client := &Client{
enabled: true,
apns: nil, // Not initialized
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.NoError(t, err) // Returns nil when not initialized
}
func TestClient_SendToIOS_CircuitBreakerOpen(t *testing.T) {
breaker := NewCircuitBreaker("apns", WithFailureThreshold(1))
breaker.RecordFailure() // Open the circuit
client := &Client{
enabled: true,
apns: &APNsClient{}, // Non-nil so we pass the nil check
apnsBreaker: breaker,
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.ErrorIs(t, err, ErrCircuitOpen)
}
func TestClient_SendToAndroid_Disabled(t *testing.T) {
client := &Client{
enabled: false,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_SendToAndroid_NilFCM(t *testing.T) {
client := &Client{
enabled: true,
fcm: nil,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_SendToAndroid_CircuitBreakerOpen(t *testing.T) {
breaker := NewCircuitBreaker("fcm", WithFailureThreshold(1))
breaker.RecordFailure()
client := &Client{
enabled: true,
fcm: &FCMClient{}, // Non-nil
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: breaker,
}
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.ErrorIs(t, err, ErrCircuitOpen)
}
func TestClient_SendToAndroid_Success(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
client := &Client{
enabled: true,
fcm: fcmClient,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_SendToAndroid_Failure_RecordsInBreaker(t *testing.T) {
server := serveFCMV1Error(t, http.StatusInternalServerError, "INTERNAL", "internal error")
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
breaker := NewCircuitBreaker("fcm", WithFailureThreshold(3))
client := &Client{
enabled: true,
fcm: fcmClient,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: breaker,
}
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
assert.Error(t, err)
assert.Equal(t, 1, breaker.Counts())
}
func TestClient_SendToAll_Disabled(t *testing.T) {
client := &Client{
enabled: false,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToAll(context.Background(), []string{"ios-token"}, []string{"android-token"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_SendToAll_EmptyTokens(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
client := &Client{
enabled: true,
fcm: fcmClient,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
// No tokens at all — should just return nil
err := client.SendToAll(context.Background(), []string{}, []string{}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_SendToAll_AndroidOnly(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
client := &Client{
enabled: true,
fcm: fcmClient,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendToAll(context.Background(), []string{}, []string{"android-token"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestClient_IsIOSEnabled(t *testing.T) {
clientWithAPNS := &Client{apns: &APNsClient{}}
clientWithoutAPNS := &Client{apns: nil}
assert.True(t, clientWithAPNS.IsIOSEnabled())
assert.False(t, clientWithoutAPNS.IsIOSEnabled())
}
func TestClient_IsAndroidEnabled(t *testing.T) {
clientWithFCM := &Client{fcm: &FCMClient{}}
clientWithoutFCM := &Client{fcm: nil}
assert.True(t, clientWithFCM.IsAndroidEnabled())
assert.False(t, clientWithoutFCM.IsAndroidEnabled())
}
func TestClient_HealthCheck(t *testing.T) {
client := &Client{
apns: nil,
fcm: nil,
}
err := client.HealthCheck(context.Background())
assert.NoError(t, err)
client.fcm = &FCMClient{}
err = client.HealthCheck(context.Background())
assert.NoError(t, err)
}
func TestClient_SendActionableNotification_Disabled(t *testing.T) {
client := &Client{
enabled: false,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
assert.NoError(t, err)
}
func TestClient_SendActionableNotification_NilAPNs(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
client := &Client{
enabled: true,
apns: nil,
fcm: fcmClient,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
// Should skip iOS and send Android
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
assert.NoError(t, err)
}
func TestClient_SendActionableNotification_APNsBreakerOpen(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
fcmClient := newTestFCMClient(server.URL)
apnsBreaker := NewCircuitBreaker("apns", WithFailureThreshold(1))
apnsBreaker.RecordFailure()
client := &Client{
enabled: true,
apns: &APNsClient{},
fcm: fcmClient,
apnsBreaker: apnsBreaker,
fcmBreaker: NewCircuitBreaker("fcm"),
}
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
// Should return ErrCircuitOpen because that was the lastErr set
assert.ErrorIs(t, err, ErrCircuitOpen)
}
// === FCM additional tests ===
func TestFCMV1Send_WithDataPayload(t *testing.T) {
var receivedData map[string]string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req fcmV1Request
_ = json.NewDecoder(r.Body).Decode(&req)
receivedData = req.Message.Data
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(fcmV1Response{Name: "projects/test/messages/0:12345"})
}))
defer server.Close()
client := newTestFCMClient(server.URL)
data := map[string]string{
"task_id": "42",
"action": "complete",
"deep_link": "/tasks/42",
}
err := client.Send(context.Background(), []string{"token"}, "Title", "Body", data)
require.NoError(t, err)
assert.Equal(t, "42", receivedData["task_id"])
assert.Equal(t, "complete", receivedData["action"])
assert.Equal(t, "/tasks/42", receivedData["deep_link"])
}
func TestFCMV1Send_ContextCancelled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The server just hangs — context cancellation should cause the request to fail
select {}
}))
defer server.Close()
client := newTestFCMClient(server.URL)
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
err := client.Send(ctx, []string{"token"}, "Title", "Body", nil)
assert.Error(t, err)
}
func TestFCMSendError_ErrorFormatting(t *testing.T) {
// Short token (no truncation)
shortErr := &FCMSendError{
Token: "abc",
StatusCode: 500,
ErrorCode: FCMErrInternal,
Message: "server error",
}
assert.Contains(t, shortErr.Error(), "abc")
assert.Contains(t, shortErr.Error(), "500")
assert.Contains(t, shortErr.Error(), "INTERNAL")
assert.Contains(t, shortErr.Error(), "server error")
}
func TestParseFCMV1Error_MalformedJSON(t *testing.T) {
result := parseFCMV1Error("token123", 500, []byte("not json"))
assert.Equal(t, 500, result.StatusCode)
assert.Contains(t, result.Message, "unparseable error response")
}
func TestParseFCMV1Error_ValidJSON(t *testing.T) {
body := `{"error":{"code":404,"message":"not found","status":"NOT_FOUND"}}`
result := parseFCMV1Error("token123", 404, []byte(body))
assert.Equal(t, 404, result.StatusCode)
assert.Equal(t, FCMErrorCode("NOT_FOUND"), result.ErrorCode)
assert.Equal(t, "not found", result.Message)
}
// === Platform constants ===
func TestPlatformConstants(t *testing.T) {
assert.Equal(t, "ios", PlatformIOS)
assert.Equal(t, "android", PlatformAndroid)
}

View File

@@ -0,0 +1,205 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestAdminRepository_Create(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{
Email: "admin@test.com",
FirstName: "Test",
LastName: "Admin",
Role: models.AdminRoleAdmin,
IsActive: true,
}
require.NoError(t, admin.SetPassword("Password123"))
err := repo.Create(admin)
require.NoError(t, err)
assert.NotZero(t, admin.ID)
}
func TestAdminRepository_Create_Duplicate(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin1 := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin1.SetPassword("Password123"))
err := repo.Create(admin1)
require.NoError(t, err)
admin2 := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin2.SetPassword("Password123"))
err = repo.Create(admin2)
assert.ErrorIs(t, err, ErrAdminExists)
}
func TestAdminRepository_FindByID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
found, err := repo.FindByID(admin.ID)
require.NoError(t, err)
assert.Equal(t, "admin@test.com", found.Email)
}
func TestAdminRepository_FindByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
_, err := repo.FindByID(9999)
assert.ErrorIs(t, err, ErrAdminNotFound)
}
func TestAdminRepository_FindByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
found, err := repo.FindByEmail("admin@test.com")
require.NoError(t, err)
assert.Equal(t, admin.ID, found.ID)
}
func TestAdminRepository_FindByEmail_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "Admin@Test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
found, err := repo.FindByEmail("admin@test.com")
require.NoError(t, err)
assert.Equal(t, admin.ID, found.ID)
}
func TestAdminRepository_FindByEmail_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
_, err := repo.FindByEmail("nonexistent@test.com")
assert.ErrorIs(t, err, ErrAdminNotFound)
}
func TestAdminRepository_Update(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", FirstName: "Old", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
admin.FirstName = "New"
err := repo.Update(admin)
require.NoError(t, err)
found, err := repo.FindByID(admin.ID)
require.NoError(t, err)
assert.Equal(t, "New", found.FirstName)
}
func TestAdminRepository_Delete(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
err := repo.Delete(admin.ID)
require.NoError(t, err)
_, err = repo.FindByID(admin.ID)
assert.ErrorIs(t, err, ErrAdminNotFound)
}
func TestAdminRepository_UpdateLastLogin(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
assert.Nil(t, admin.LastLogin)
err := repo.UpdateLastLogin(admin.ID)
require.NoError(t, err)
found, err := repo.FindByID(admin.ID)
require.NoError(t, err)
assert.NotNil(t, found.LastLogin)
}
func TestAdminRepository_List(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
for i := 0; i < 5; i++ {
admin := &models.AdminUser{
Email: "admin" + string(rune('0'+i)) + "@test.com",
Role: models.AdminRoleAdmin,
}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
}
// Page 1
admins, total, err := repo.List(1, 3)
require.NoError(t, err)
assert.Equal(t, int64(5), total)
assert.Len(t, admins, 3)
// Page 2
admins, total, err = repo.List(2, 3)
require.NoError(t, err)
assert.Equal(t, int64(5), total)
assert.Len(t, admins, 2)
}
func TestAdminRepository_ExistsByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
exists, err := repo.ExistsByEmail("admin@test.com")
require.NoError(t, err)
assert.True(t, exists)
exists, err = repo.ExistsByEmail("nonexistent@test.com")
require.NoError(t, err)
assert.False(t, exists)
}
func TestAdminRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewAdminRepository(db)
admin := &models.AdminUser{Email: "Admin@Test.com", Role: models.AdminRoleAdmin}
require.NoError(t, admin.SetPassword("Password123"))
require.NoError(t, repo.Create(admin))
exists, err := repo.ExistsByEmail("admin@test.com")
require.NoError(t, err)
assert.True(t, exists)
}

View File

@@ -0,0 +1,356 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestContractorRepository_Create(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := &models.Contractor{
ResidenceID: &residence.ID,
CreatedByID: user.ID,
Name: "Mike's Plumbing",
Company: "Mike's Plumbing Co.",
Phone: "+1-555-1234",
Email: "mike@plumbing.com",
IsActive: true,
}
err := repo.Create(contractor)
require.NoError(t, err)
assert.NotZero(t, contractor.ID)
}
func TestContractorRepository_FindByID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
found, err := repo.FindByID(contractor.ID)
require.NoError(t, err)
assert.Equal(t, contractor.ID, found.ID)
assert.Equal(t, "Test Contractor", found.Name)
}
func TestContractorRepository_FindByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
_, err := repo.FindByID(9999)
assert.Error(t, err)
}
func TestContractorRepository_FindByID_InactiveNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Inactive Contractor")
// Soft-delete
db.Model(&models.Contractor{}).Where("id = ?", contractor.ID).Update("is_active", false)
_, err := repo.FindByID(contractor.ID)
assert.Error(t, err)
}
func TestContractorRepository_FindByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
otherResidence := testutil.CreateTestResidence(t, db, user.ID, "Other House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
testutil.CreateTestContractor(t, db, otherResidence.ID, user.ID, "Other Contractor")
contractors, err := repo.FindByResidence(residence.ID)
require.NoError(t, err)
assert.Len(t, contractors, 2)
}
func TestContractorRepository_FindByResidence_ExcludesInactive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
active := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Active Contractor")
inactive := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Inactive Contractor")
db.Model(&models.Contractor{}).Where("id = ?", inactive.ID).Update("is_active", false)
contractors, err := repo.FindByResidence(residence.ID)
require.NoError(t, err)
assert.Len(t, contractors, 1)
assert.Equal(t, active.ID, contractors[0].ID)
}
func TestContractorRepository_FindByUser_WithResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Residence Contractor")
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
require.NoError(t, err)
assert.Len(t, contractors, 1)
assert.Equal(t, "Residence Contractor", contractors[0].Name)
}
func TestContractorRepository_FindByUser_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Personal contractor with no residence
personal := &models.Contractor{
ResidenceID: nil,
CreatedByID: user.ID,
Name: "Personal Contractor",
IsActive: true,
}
err := db.Create(personal).Error
require.NoError(t, err)
contractors, err := repo.FindByUser(user.ID, []uint{})
require.NoError(t, err)
assert.Len(t, contractors, 1)
assert.Equal(t, "Personal Contractor", contractors[0].Name)
}
func TestContractorRepository_Update(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Original Name")
contractor.Name = "Updated Name"
contractor.Phone = "+1-555-9999"
err := repo.Update(contractor)
require.NoError(t, err)
found, err := repo.FindByID(contractor.ID)
require.NoError(t, err)
assert.Equal(t, "Updated Name", found.Name)
assert.Equal(t, "+1-555-9999", found.Phone)
}
func TestContractorRepository_Delete(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
err := repo.Delete(contractor.ID)
require.NoError(t, err)
// Should not be found (soft delete)
_, err = repo.FindByID(contractor.ID)
assert.Error(t, err)
}
func TestContractorRepository_CountByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
count, err := repo.CountByResidence(residence.ID)
require.NoError(t, err)
assert.Equal(t, int64(2), count)
}
func TestContractorRepository_CountByResidenceIDs(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestContractor(t, db, r1.ID, user.ID, "Contractor A")
testutil.CreateTestContractor(t, db, r2.ID, user.ID, "Contractor B")
testutil.CreateTestContractor(t, db, r2.ID, user.ID, "Contractor C")
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestContractorRepository_CountByResidenceIDs_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
count, err := repo.CountByResidenceIDs([]uint{})
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestContractorRepository_SetSpecialties(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Skilled Contractor")
// Get seeded specialties
var specialties []models.ContractorSpecialty
err := db.Find(&specialties).Error
require.NoError(t, err)
require.GreaterOrEqual(t, len(specialties), 2)
// Set specialties
err = repo.SetSpecialties(contractor.ID, []uint{specialties[0].ID, specialties[1].ID})
require.NoError(t, err)
// Verify via FindByID
found, err := repo.FindByID(contractor.ID)
require.NoError(t, err)
assert.Len(t, found.Specialties, 2)
}
func TestContractorRepository_SetSpecialties_ClearsExisting(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Skilled Contractor")
var specialties []models.ContractorSpecialty
err := db.Find(&specialties).Error
require.NoError(t, err)
require.GreaterOrEqual(t, len(specialties), 2)
// Set initial specialties
err = repo.SetSpecialties(contractor.ID, []uint{specialties[0].ID, specialties[1].ID})
require.NoError(t, err)
// Clear all specialties
err = repo.SetSpecialties(contractor.ID, []uint{})
require.NoError(t, err)
found, err := repo.FindByID(contractor.ID)
require.NoError(t, err)
assert.Len(t, found.Specialties, 0)
}
func TestContractorRepository_GetAllSpecialties(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewContractorRepository(db)
specialties, err := repo.GetAllSpecialties()
require.NoError(t, err)
assert.GreaterOrEqual(t, len(specialties), 4) // Seeded 4 specialties
}
func TestContractorRepository_FindSpecialtyByID(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewContractorRepository(db)
var seeded models.ContractorSpecialty
err := db.First(&seeded).Error
require.NoError(t, err)
found, err := repo.FindSpecialtyByID(seeded.ID)
require.NoError(t, err)
assert.Equal(t, seeded.Name, found.Name)
}
func TestContractorRepository_FindSpecialtyByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
_, err := repo.FindSpecialtyByID(9999)
assert.Error(t, err)
}
func TestContractorRepository_GetTasksForContractor(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Task Contractor")
// Create tasks linked to contractor
task1 := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Task 1",
ContractorID: &contractor.ID,
Version: 1,
}
err := db.Create(task1).Error
require.NoError(t, err)
task2 := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Task 2",
ContractorID: &contractor.ID,
Version: 1,
}
err = db.Create(task2).Error
require.NoError(t, err)
tasks, err := repo.GetTasksForContractor(contractor.ID)
require.NoError(t, err)
assert.Len(t, tasks, 2)
}
func TestContractorRepository_FindByResidence_FavoritesFirst(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create non-favorite first (alphabetically first)
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Alpha Contractor")
// Create favorite
fav := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Zeta Contractor")
db.Model(&models.Contractor{}).Where("id = ?", fav.ID).Update("is_favorite", true)
contractors, err := repo.FindByResidence(residence.ID)
require.NoError(t, err)
assert.Len(t, contractors, 2)
// Favorite should be first
assert.Equal(t, fav.ID, contractors[0].ID)
}

View File

@@ -0,0 +1,384 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestDocumentRepository_Create(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "HVAC Warranty",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/hvac.pdf",
IsActive: true,
}
err := repo.Create(doc)
require.NoError(t, err)
assert.NotZero(t, doc.ID)
}
func TestDocumentRepository_FindByID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
found, err := repo.FindByID(doc.ID)
require.NoError(t, err)
assert.Equal(t, doc.ID, found.ID)
assert.Equal(t, "Test Doc", found.Title)
}
func TestDocumentRepository_FindByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
_, err := repo.FindByID(9999)
assert.Error(t, err)
}
func TestDocumentRepository_FindByID_InactiveNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
db.Model(&models.Document{}).Where("id = ?", doc.ID).Update("is_active", false)
_, err := repo.FindByID(doc.ID)
assert.Error(t, err)
}
func TestDocumentRepository_FindByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
otherResidence := testutil.CreateTestResidence(t, db, user.ID, "Other House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc A")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc B")
testutil.CreateTestDocument(t, db, otherResidence.ID, user.ID, "Other Doc")
docs, err := repo.FindByResidence(residence.ID)
require.NoError(t, err)
assert.Len(t, docs, 2)
}
func TestDocumentRepository_FindByResidence_ExcludesInactive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
active := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
docs, err := repo.FindByResidence(residence.ID)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, active.ID, docs[0].ID)
}
func TestDocumentRepository_FindByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestDocument(t, db, r1.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc 2")
docs, err := repo.FindByUser([]uint{r1.ID, r2.ID})
require.NoError(t, err)
assert.Len(t, docs, 2)
}
func TestDocumentRepository_Update(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
doc.Title = "Updated Title"
doc.Description = "Updated description"
err := repo.Update(doc)
require.NoError(t, err)
found, err := repo.FindByID(doc.ID)
require.NoError(t, err)
assert.Equal(t, "Updated Title", found.Title)
assert.Equal(t, "Updated description", found.Description)
}
func TestDocumentRepository_Delete(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
err := repo.Delete(doc.ID)
require.NoError(t, err)
_, err = repo.FindByID(doc.ID)
assert.Error(t, err)
}
func TestDocumentRepository_ActivateDeactivate(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Toggle Doc")
// Deactivate
err := repo.Deactivate(doc.ID)
require.NoError(t, err)
_, err = repo.FindByID(doc.ID)
assert.Error(t, err) // Not found when inactive
// Activate
err = repo.Activate(doc.ID)
require.NoError(t, err)
found, err := repo.FindByID(doc.ID)
require.NoError(t, err)
assert.Equal(t, doc.ID, found.ID)
}
func TestDocumentRepository_CountByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 3")
count, err := repo.CountByResidence(residence.ID)
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestDocumentRepository_CountByResidenceIDs(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestDocument(t, db, r1.ID, user.ID, "Doc A")
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc B")
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc C")
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestDocumentRepository_CountByResidenceIDs_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
count, err := repo.CountByResidenceIDs([]uint{})
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestDocumentRepository_DocumentImages(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc with Images")
// Create image
img := &models.DocumentImage{
DocumentID: doc.ID,
ImageURL: "https://example.com/img1.jpg",
Caption: "Front page",
}
err := repo.CreateDocumentImage(img)
require.NoError(t, err)
assert.NotZero(t, img.ID)
// Find image by ID
found, err := repo.FindImageByID(img.ID)
require.NoError(t, err)
assert.Equal(t, "https://example.com/img1.jpg", found.ImageURL)
assert.Equal(t, "Front page", found.Caption)
// Delete single image
err = repo.DeleteDocumentImage(img.ID)
require.NoError(t, err)
_, err = repo.FindImageByID(img.ID)
assert.Error(t, err)
}
func TestDocumentRepository_DeleteDocumentImages(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc with Images")
// Create multiple images
for i := 0; i < 3; i++ {
img := &models.DocumentImage{
DocumentID: doc.ID,
ImageURL: "https://example.com/img.jpg",
}
err := repo.CreateDocumentImage(img)
require.NoError(t, err)
}
// Delete all images for document
err := repo.DeleteDocumentImages(doc.ID)
require.NoError(t, err)
// Verify all deleted
var count int64
db.Model(&models.DocumentImage{}).Where("document_id = ?", doc.ID).Count(&count)
assert.Equal(t, int64(0), count)
}
func TestDocumentRepository_FindByIDIncludingInactive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Deactivated Doc")
// Deactivate
db.Model(&models.Document{}).Where("id = ?", doc.ID).Update("is_active", false)
// FindByID should not find it
_, err := repo.FindByID(doc.ID)
assert.Error(t, err)
// FindByIDIncludingInactive should find it
var found models.Document
err = repo.FindByIDIncludingInactive(doc.ID, &found)
require.NoError(t, err)
assert.Equal(t, doc.ID, found.ID)
}
func TestDocumentRepository_FindByUserFiltered_ByType(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create general document
generalDoc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "General Doc",
DocumentType: models.DocumentTypeGeneral,
FileURL: "https://example.com/gen.pdf",
IsActive: true,
}
err := db.Create(generalDoc).Error
require.NoError(t, err)
// Create warranty document
warrantyDoc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Warranty Doc",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/warranty.pdf",
IsActive: true,
}
err = db.Create(warrantyDoc).Error
require.NoError(t, err)
// Filter by warranty type
filter := &DocumentFilter{
DocumentType: string(models.DocumentTypeWarranty),
}
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "Warranty Doc", docs[0].Title)
}
func TestDocumentRepository_FindByUserFiltered_NilFilter(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, nil)
require.NoError(t, err)
assert.Len(t, docs, 2)
}
func TestDocumentRepository_FindWarranties(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// General doc (should not be returned)
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
// Warranty doc
warranty := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Warranty",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/warranty.pdf",
IsActive: true,
}
err := db.Create(warranty).Error
require.NoError(t, err)
docs, err := repo.FindWarranties([]uint{residence.ID})
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "Warranty", docs[0].Title)
}

View File

@@ -0,0 +1,207 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestDocumentRepository_FindExpiringWarranties(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
now := time.Now().UTC()
expiringSoon := now.AddDate(0, 0, 15) // 15 days from now
expiringLater := now.AddDate(0, 0, 60) // 60 days from now
alreadyExpired := now.AddDate(0, 0, -5) // 5 days ago
// Warranty expiring in 15 days (within 30 day threshold)
w1 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Expiring Warranty",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/w1.pdf",
ExpiryDate: &expiringSoon,
IsActive: true,
}
require.NoError(t, db.Create(w1).Error)
// Warranty expiring in 60 days (outside 30 day threshold)
w2 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Later Warranty",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/w2.pdf",
ExpiryDate: &expiringLater,
IsActive: true,
}
require.NoError(t, db.Create(w2).Error)
// Already expired warranty
w3 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Expired Warranty",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/w3.pdf",
ExpiryDate: &alreadyExpired,
IsActive: true,
}
require.NoError(t, db.Create(w3).Error)
// General document (not warranty)
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
docs, err := repo.FindExpiringWarranties([]uint{residence.ID}, 30)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "Expiring Warranty", docs[0].Title)
}
func TestDocumentRepository_FindByUserFiltered_Search(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite does not support ILIKE")
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
d1 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "HVAC Warranty Certificate",
Description: "For the main HVAC unit",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/hvac.pdf",
IsActive: true,
}
require.NoError(t, db.Create(d1).Error)
d2 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Plumbing Receipt",
Description: "Bathroom plumbing fix",
DocumentType: models.DocumentTypeReceipt,
FileURL: "https://example.com/plumbing.pdf",
IsActive: true,
}
require.NoError(t, db.Create(d2).Error)
// Search by title
filter := &DocumentFilter{Search: "HVAC"}
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "HVAC Warranty Certificate", docs[0].Title)
// Search by description
filter = &DocumentFilter{Search: "bathroom"}
docs, err = repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "Plumbing Receipt", docs[0].Title)
}
func TestDocumentRepository_FindByUserFiltered_IsActiveOverride(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
active := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
// With IsActive = false, should get only inactive
isActiveFalse := false
filter := &DocumentFilter{IsActive: &isActiveFalse}
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, inactive.ID, docs[0].ID)
// With IsActive = true, should get only active
isActiveTrue := true
filter = &DocumentFilter{IsActive: &isActiveTrue}
docs, err = repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, active.ID, docs[0].ID)
}
func TestDocumentRepository_FindByUserFiltered_ExpiringSoon(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
now := time.Now().UTC()
expiringSoon := now.AddDate(0, 0, 10)
notExpiring := now.AddDate(0, 0, 90)
d1 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Expiring Doc",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/exp.pdf",
ExpiryDate: &expiringSoon,
IsActive: true,
}
require.NoError(t, db.Create(d1).Error)
d2 := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Not Expiring Doc",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/noexp.pdf",
ExpiryDate: &notExpiring,
IsActive: true,
}
require.NoError(t, db.Create(d2).Error)
days := 30
filter := &DocumentFilter{ExpiringSoon: &days}
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
require.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "Expiring Doc", docs[0].Title)
}
func TestDocumentRepository_FindImageByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
_, err := repo.FindImageByID(9999)
assert.Error(t, err)
}
func TestDocumentRepository_FindByUser_ExcludesInactive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
docs, err := repo.FindByUser([]uint{residence.ID})
require.NoError(t, err)
assert.Len(t, docs, 1)
}

View File

@@ -0,0 +1,510 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestNotificationRepository_Create(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
notification := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Task Due Soon",
Body: "Your task is due tomorrow",
}
err := repo.Create(notification)
require.NoError(t, err)
assert.NotZero(t, notification.ID)
}
func TestNotificationRepository_FindByID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
notification := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskOverdue,
Title: "Task Overdue",
Body: "Your task is overdue",
}
err := db.Create(notification).Error
require.NoError(t, err)
found, err := repo.FindByID(notification.ID)
require.NoError(t, err)
assert.Equal(t, "Task Overdue", found.Title)
assert.Equal(t, user.ID, found.UserID)
}
func TestNotificationRepository_FindByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
_, err := repo.FindByID(9999)
assert.Error(t, err)
}
func TestNotificationRepository_FindByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
// Create notifications for user
for i := 0; i < 5; i++ {
n := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Notification",
Body: "Body",
}
err := db.Create(n).Error
require.NoError(t, err)
}
// Create notification for other user
otherN := &models.Notification{
UserID: otherUser.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Other",
Body: "Body",
}
err := db.Create(otherN).Error
require.NoError(t, err)
// Find all for user
notifications, err := repo.FindByUser(user.ID, 0, 0)
require.NoError(t, err)
assert.Len(t, notifications, 5)
// Find with limit
notifications, err = repo.FindByUser(user.ID, 3, 0)
require.NoError(t, err)
assert.Len(t, notifications, 3)
// Find with offset
notifications, err = repo.FindByUser(user.ID, 3, 3)
require.NoError(t, err)
assert.Len(t, notifications, 2) // Only 2 remaining
}
func TestNotificationRepository_MarkAsRead(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
notification := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskCompleted,
Title: "Task Completed",
Body: "Your task was completed",
Read: false,
}
err := db.Create(notification).Error
require.NoError(t, err)
err = repo.MarkAsRead(notification.ID)
require.NoError(t, err)
found, err := repo.FindByID(notification.ID)
require.NoError(t, err)
assert.True(t, found.Read)
assert.NotNil(t, found.ReadAt)
}
func TestNotificationRepository_MarkAllAsRead(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create multiple unread notifications
for i := 0; i < 3; i++ {
n := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Unread",
Body: "Body",
Read: false,
}
err := db.Create(n).Error
require.NoError(t, err)
}
err := repo.MarkAllAsRead(user.ID)
require.NoError(t, err)
// Verify all are read
count, err := repo.CountUnread(user.ID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
func TestNotificationRepository_CountUnread(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create 3 unread and 2 read notifications
for i := 0; i < 3; i++ {
n := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Unread",
Body: "Body",
Read: false,
}
err := db.Create(n).Error
require.NoError(t, err)
}
for i := 0; i < 2; i++ {
n := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskCompleted,
Title: "Read",
Body: "Body",
Read: true,
}
err := db.Create(n).Error
require.NoError(t, err)
}
count, err := repo.CountUnread(user.ID)
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestNotificationRepository_MarkAsSent(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
notification := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "To Send",
Body: "Body",
Sent: false,
}
err := db.Create(notification).Error
require.NoError(t, err)
err = repo.MarkAsSent(notification.ID)
require.NoError(t, err)
found, err := repo.FindByID(notification.ID)
require.NoError(t, err)
assert.True(t, found.Sent)
assert.NotNil(t, found.SentAt)
}
func TestNotificationRepository_SetError(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
notification := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Error Notification",
Body: "Body",
}
err := db.Create(notification).Error
require.NoError(t, err)
err = repo.SetError(notification.ID, "failed to send: connection timeout")
require.NoError(t, err)
found, err := repo.FindByID(notification.ID)
require.NoError(t, err)
assert.Equal(t, "failed to send: connection timeout", found.ErrorMessage)
}
func TestNotificationRepository_GetPendingNotifications(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create pending (unsent) notifications
for i := 0; i < 5; i++ {
n := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Pending",
Body: "Body",
Sent: false,
}
err := db.Create(n).Error
require.NoError(t, err)
}
// Create sent notification
sent := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskCompleted,
Title: "Already Sent",
Body: "Body",
Sent: true,
}
err := db.Create(sent).Error
require.NoError(t, err)
pending, err := repo.GetPendingNotifications(10)
require.NoError(t, err)
assert.Len(t, pending, 5) // Only unsent
// Test with limit
pending, err = repo.GetPendingNotifications(3)
require.NoError(t, err)
assert.Len(t, pending, 3)
}
func TestNotificationRepository_APNSDevice_CRUD(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create APNS device
device := &models.APNSDevice{
Name: "iPhone 15",
Active: true,
UserID: &user.ID,
DeviceID: "device-uuid-123",
RegistrationID: "apns-token-abc123",
}
err := repo.CreateAPNSDevice(device)
require.NoError(t, err)
assert.NotZero(t, device.ID)
// Find by ID
found, err := repo.FindAPNSDeviceByID(device.ID)
require.NoError(t, err)
assert.Equal(t, "iPhone 15", found.Name)
assert.Equal(t, "apns-token-abc123", found.RegistrationID)
// Find by token
foundByToken, err := repo.FindAPNSDeviceByToken("apns-token-abc123")
require.NoError(t, err)
assert.Equal(t, device.ID, foundByToken.ID)
// Find by user
devices, err := repo.FindAPNSDevicesByUser(user.ID)
require.NoError(t, err)
assert.Len(t, devices, 1)
// Update
device.Name = "iPhone 16"
err = repo.UpdateAPNSDevice(device)
require.NoError(t, err)
found, err = repo.FindAPNSDeviceByID(device.ID)
require.NoError(t, err)
assert.Equal(t, "iPhone 16", found.Name)
// Deactivate
err = repo.DeactivateAPNSDevice(device.ID)
require.NoError(t, err)
// Should not appear in active devices
devices, err = repo.FindAPNSDevicesByUser(user.ID)
require.NoError(t, err)
assert.Len(t, devices, 0)
// Delete
err = repo.DeleteAPNSDevice(device.ID)
require.NoError(t, err)
_, err = repo.FindAPNSDeviceByID(device.ID)
assert.Error(t, err)
}
func TestNotificationRepository_GCMDevice_CRUD(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create GCM device
device := &models.GCMDevice{
Name: "Pixel 8",
Active: true,
UserID: &user.ID,
DeviceID: "android-device-uuid",
RegistrationID: "fcm-token-xyz789",
CloudMessageType: "FCM",
}
err := repo.CreateGCMDevice(device)
require.NoError(t, err)
assert.NotZero(t, device.ID)
// Find by ID
found, err := repo.FindGCMDeviceByID(device.ID)
require.NoError(t, err)
assert.Equal(t, "Pixel 8", found.Name)
assert.Equal(t, "fcm-token-xyz789", found.RegistrationID)
// Find by token
foundByToken, err := repo.FindGCMDeviceByToken("fcm-token-xyz789")
require.NoError(t, err)
assert.Equal(t, device.ID, foundByToken.ID)
// Find by user
devices, err := repo.FindGCMDevicesByUser(user.ID)
require.NoError(t, err)
assert.Len(t, devices, 1)
// Update
device.Name = "Pixel 9"
err = repo.UpdateGCMDevice(device)
require.NoError(t, err)
found, err = repo.FindGCMDeviceByID(device.ID)
require.NoError(t, err)
assert.Equal(t, "Pixel 9", found.Name)
// Deactivate
err = repo.DeactivateGCMDevice(device.ID)
require.NoError(t, err)
devices, err = repo.FindGCMDevicesByUser(user.ID)
require.NoError(t, err)
assert.Len(t, devices, 0)
// Delete
err = repo.DeleteGCMDevice(device.ID)
require.NoError(t, err)
_, err = repo.FindGCMDeviceByID(device.ID)
assert.Error(t, err)
}
func TestNotificationRepository_GetActiveTokensForUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create APNS devices
apns1 := &models.APNSDevice{
Active: true,
UserID: &user.ID,
RegistrationID: "ios-token-1",
}
err := db.Create(apns1).Error
require.NoError(t, err)
apns2 := &models.APNSDevice{
Active: true,
UserID: &user.ID,
RegistrationID: "ios-token-2",
}
err = db.Create(apns2).Error
require.NoError(t, err)
// Create inactive APNS device (should not be returned)
// Insert inactive device via raw SQL to bypass GORM's default:true override
err = db.Exec("INSERT INTO push_notifications_apnsdevice (active, user_id, registration_id, date_created) VALUES (false, ?, 'ios-token-inactive', CURRENT_TIMESTAMP)", user.ID).Error
require.NoError(t, err)
// Create GCM device
gcm1 := &models.GCMDevice{
Active: true,
UserID: &user.ID,
RegistrationID: "android-token-1",
CloudMessageType: "FCM",
}
err = db.Create(gcm1).Error
require.NoError(t, err)
iosTokens, androidTokens, err := repo.GetActiveTokensForUser(user.ID)
require.NoError(t, err)
assert.Len(t, iosTokens, 2)
assert.Contains(t, iosTokens, "ios-token-1")
assert.Contains(t, iosTokens, "ios-token-2")
assert.Len(t, androidTokens, 1)
assert.Contains(t, androidTokens, "android-token-1")
}
func TestNotificationRepository_GetActiveTokensForUser_NoDevices(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
iosTokens, androidTokens, err := repo.GetActiveTokensForUser(user.ID)
require.NoError(t, err)
assert.Empty(t, iosTokens)
assert.Empty(t, androidTokens)
}
func TestNotificationRepository_FindPreferencesByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create preferences with defaults (GORM applies default:true for bools)
prefs := &models.NotificationPreference{
UserID: user.ID,
}
err := db.Create(prefs).Error
require.NoError(t, err)
// Update one field to false via raw SQL to bypass GORM default
err = db.Exec("UPDATE notifications_notificationpreference SET task_due_soon = false WHERE user_id = ?", user.ID).Error
require.NoError(t, err)
found, err := repo.FindPreferencesByUser(user.ID)
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
assert.False(t, found.TaskDueSoon)
assert.True(t, found.TaskOverdue)
}
func TestNotificationRepository_FindPreferencesByUser_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
_, err := repo.FindPreferencesByUser(9999)
assert.Error(t, err)
}
func TestNotificationRepository_UpdatePreferences(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
prefs, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
prefs.TaskDueSoon = false
prefs.DailyDigest = false
err = repo.UpdatePreferences(prefs)
require.NoError(t, err)
found, err := repo.FindPreferencesByUser(user.ID)
require.NoError(t, err)
assert.False(t, found.TaskDueSoon)
assert.False(t, found.DailyDigest)
}

View File

@@ -0,0 +1,217 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestReminderRepository_LogReminder(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
logEntry, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
require.NoError(t, err)
assert.NotZero(t, logEntry.ID)
assert.Equal(t, task.ID, logEntry.TaskID)
assert.Equal(t, user.ID, logEntry.UserID)
assert.Equal(t, models.ReminderStage7Days, logEntry.ReminderStage)
}
func TestReminderRepository_HasSentReminder(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
dueDate := time.Date(2026, 4, 15, 12, 30, 0, 0, time.UTC) // Has time component
// Not sent yet
sent, err := repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days)
require.NoError(t, err)
assert.False(t, sent)
// Log it
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
require.NoError(t, err)
// Now should be sent
sent, err = repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days)
require.NoError(t, err)
assert.True(t, sent)
// Different stage should not be sent
sent, err = repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days)
require.NoError(t, err)
assert.False(t, sent)
}
func TestReminderRepository_HasSentReminderBatch(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task1 := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
task2 := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
// Log reminder for task1
_, err := repo.LogReminder(task1.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
require.NoError(t, err)
keys := []ReminderKey{
{TaskID: task1.ID, UserID: user.ID, DueDate: dueDate, Stage: models.ReminderStage7Days},
{TaskID: task2.ID, UserID: user.ID, DueDate: dueDate, Stage: models.ReminderStage7Days},
}
result, err := repo.HasSentReminderBatch(keys)
require.NoError(t, err)
assert.True(t, result[0], "task1 reminder should be sent")
assert.False(t, result[1], "task2 reminder should not be sent")
}
func TestReminderRepository_HasSentReminderBatch_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
result, err := repo.HasSentReminderBatch([]ReminderKey{})
require.NoError(t, err)
assert.Empty(t, result)
}
func TestReminderRepository_GetSentRemindersForTask(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
_, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
require.NoError(t, err)
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days, nil)
require.NoError(t, err)
logs, err := repo.GetSentRemindersForTask(task.ID, user.ID)
require.NoError(t, err)
assert.Len(t, logs, 2)
}
func TestReminderRepository_GetSentRemindersForDueDate(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
dueDate1 := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
dueDate2 := time.Date(2026, 5, 15, 0, 0, 0, 0, time.UTC)
_, err := repo.LogReminder(task.ID, user.ID, dueDate1, models.ReminderStage7Days, nil)
require.NoError(t, err)
_, err = repo.LogReminder(task.ID, user.ID, dueDate2, models.ReminderStage7Days, nil)
require.NoError(t, err)
logs, err := repo.GetSentRemindersForDueDate(task.ID, user.ID, dueDate1)
require.NoError(t, err)
assert.Len(t, logs, 1)
}
func TestReminderRepository_CleanupOldLogs(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
// Create old log entry (100 days ago)
oldLog := &models.TaskReminderLog{
TaskID: task.ID,
UserID: user.ID,
DueDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
ReminderStage: models.ReminderStage7Days,
SentAt: time.Now().UTC().AddDate(0, 0, -100),
}
require.NoError(t, db.Create(oldLog).Error)
// Create recent log entry
recentLog := &models.TaskReminderLog{
TaskID: task.ID,
UserID: user.ID,
DueDate: time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC),
ReminderStage: models.ReminderStage3Days,
SentAt: time.Now().UTC(),
}
require.NoError(t, db.Create(recentLog).Error)
deleted, err := repo.CleanupOldLogs(90)
require.NoError(t, err)
assert.Equal(t, int64(1), deleted)
// Recent log should still exist
var count int64
db.Model(&models.TaskReminderLog{}).Count(&count)
assert.Equal(t, int64(1), count)
}
func TestReminderRepository_GetRecentReminderStats(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
// Create recent reminders
_, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
require.NoError(t, err)
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days, nil)
require.NoError(t, err)
stats, err := repo.GetRecentReminderStats(24) // last 24 hours
require.NoError(t, err)
assert.Equal(t, int64(1), stats[string(models.ReminderStage7Days)])
assert.Equal(t, int64(1), stats[string(models.ReminderStage3Days)])
}
func TestReminderRepository_LogReminder_WithNotificationID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewReminderRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
// Create a notification
notif := &models.Notification{
UserID: user.ID,
NotificationType: models.NotificationTaskDueSoon,
Title: "Test",
Body: "Test body",
}
require.NoError(t, db.Create(notif).Error)
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
logEntry, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, &notif.ID)
require.NoError(t, err)
assert.NotNil(t, logEntry.NotificationID)
assert.Equal(t, notif.ID, *logEntry.NotificationID)
}

View File

@@ -0,0 +1,216 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestResidenceRepository_FindByIDSimple(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
found, err := repo.FindByIDSimple(residence.ID)
require.NoError(t, err)
assert.Equal(t, residence.ID, found.ID)
assert.Equal(t, "Test House", found.Name)
}
func TestResidenceRepository_FindByIDSimple_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
_, err := repo.FindByIDSimple(9999)
assert.Error(t, err)
}
func TestResidenceRepository_FindByIDSimple_InactiveNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
db.Model(residence).Update("is_active", false)
_, err := repo.FindByIDSimple(residence.ID)
assert.Error(t, err)
}
func TestResidenceRepository_FindResidenceIDsByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, owner.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, owner.ID, "House 2")
repo.AddUser(r1.ID, sharedUser.ID)
// Owner should see both
ownerIDs, err := repo.FindResidenceIDsByUser(owner.ID)
require.NoError(t, err)
assert.Len(t, ownerIDs, 2)
// Shared user should see one
sharedIDs, err := repo.FindResidenceIDsByUser(sharedUser.ID)
require.NoError(t, err)
assert.Len(t, sharedIDs, 1)
assert.Equal(t, r1.ID, sharedIDs[0])
_ = r2 // suppress unused
}
func TestResidenceRepository_FindResidenceIDsByOwner(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
otherOwner := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
testutil.CreateTestResidence(t, db, owner.ID, "House 1")
testutil.CreateTestResidence(t, db, owner.ID, "House 2")
testutil.CreateTestResidence(t, db, otherOwner.ID, "Other House")
ids, err := repo.FindResidenceIDsByOwner(owner.ID)
require.NoError(t, err)
assert.Len(t, ids, 2)
}
func TestResidenceRepository_DeactivateShareCode(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
// Create a share code
created, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000) // 24h in ns
require.NoError(t, err)
assert.True(t, created.IsActive)
// Deactivate it
err = repo.DeactivateShareCode(created.ID)
require.NoError(t, err)
// FindShareCodeByCode should not find it
_, err = repo.FindShareCodeByCode(created.Code)
assert.Error(t, err)
}
func TestResidenceRepository_GetActiveShareCode(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
// No active code initially
code, err := repo.GetActiveShareCode(residence.ID)
require.NoError(t, err)
assert.Nil(t, code)
// Create a share code
created, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
require.NoError(t, err)
// Should find the active code
code, err = repo.GetActiveShareCode(residence.ID)
require.NoError(t, err)
require.NotNil(t, code)
assert.Equal(t, created.Code, code.Code)
}
func TestResidenceRepository_CreateShareCode_DeactivatesPrevious(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
// Create first code
first, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
require.NoError(t, err)
// Create second code (should deactivate first)
second, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
require.NoError(t, err)
assert.NotEqual(t, first.Code, second.Code)
// First code should no longer be findable
_, err = repo.FindShareCodeByCode(first.Code)
assert.Error(t, err)
// Second code should be active
found, err := repo.FindShareCodeByCode(second.Code)
require.NoError(t, err)
assert.Equal(t, second.Code, found.Code)
}
func TestResidenceRepository_FindResidenceTypeByID(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewResidenceRepository(db)
// Get first residence type
types, err := repo.GetAllResidenceTypes()
require.NoError(t, err)
require.NotEmpty(t, types)
found, err := repo.FindResidenceTypeByID(types[0].ID)
require.NoError(t, err)
assert.Equal(t, types[0].Name, found.Name)
}
func TestResidenceRepository_FindResidenceTypeByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
_, err := repo.FindResidenceTypeByID(9999)
assert.Error(t, err)
}
func TestResidenceRepository_GetTasksForReport(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
tasks, err := repo.GetTasksForReport(residence.ID)
require.NoError(t, err)
assert.Len(t, tasks, 2)
}
func TestResidenceRepository_AddUser_Idempotent(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
// Add same user twice (should not error due to ON CONFLICT DO NOTHING)
err := repo.AddUser(residence.ID, sharedUser.ID)
require.NoError(t, err)
err = repo.AddUser(residence.ID, sharedUser.ID)
require.NoError(t, err)
// Should still only count once
users, err := repo.GetResidenceUsers(residence.ID)
require.NoError(t, err)
assert.Len(t, users, 2) // owner + sharedUser
}

View File

@@ -0,0 +1,418 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestSubscriptionRepository_FindByUserID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
}
require.NoError(t, db.Create(sub).Error)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierPro, found.Tier)
}
func TestSubscriptionRepository_FindByUserID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.FindByUserID(9999)
assert.Error(t, err)
}
func TestSubscriptionRepository_Update(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
sub, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
sub.Tier = models.TierPro
err = repo.Update(sub)
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierPro, found.Tier)
}
func TestSubscriptionRepository_UpgradeToPro(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
err = repo.UpgradeToPro(user.ID, expiresAt, "apple")
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierPro, found.Tier)
assert.Equal(t, "apple", found.Platform)
assert.True(t, found.AutoRenew)
require.NotNil(t, found.SubscribedAt)
require.NotNil(t, found.ExpiresAt)
}
func TestSubscriptionRepository_DowngradeToFree(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
// First upgrade
err = repo.UpgradeToPro(user.ID, time.Now().UTC().AddDate(1, 0, 0), "apple")
require.NoError(t, err)
// Then downgrade
err = repo.DowngradeToFree(user.ID)
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, found.Tier)
assert.False(t, found.AutoRenew)
require.NotNil(t, found.CancelledAt)
}
func TestSubscriptionRepository_SetAutoRenew(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
err = repo.SetAutoRenew(user.ID, false)
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.False(t, found.AutoRenew)
}
func TestSubscriptionRepository_UpdateReceiptData(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
err = repo.UpdateReceiptData(user.ID, "receipt_data_abc123")
require.NoError(t, err)
var sub models.UserSubscription
db.Where("user_id = ?", user.ID).First(&sub)
require.NotNil(t, sub.AppleReceiptData)
assert.Equal(t, "receipt_data_abc123", *sub.AppleReceiptData)
}
func TestSubscriptionRepository_UpdatePurchaseToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
err = repo.UpdatePurchaseToken(user.ID, "google_token_xyz")
require.NoError(t, err)
var sub models.UserSubscription
db.Where("user_id = ?", user.ID).First(&sub)
require.NotNil(t, sub.GooglePurchaseToken)
assert.Equal(t, "google_token_xyz", *sub.GooglePurchaseToken)
}
func TestSubscriptionRepository_FindByAppleReceiptContains(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Use raw SQL to ensure apple_receipt_data is stored correctly (GORM may omit pointer fields)
require.NoError(t, db.Exec(
"INSERT INTO subscription_usersubscription (user_id, tier, apple_receipt_data, created_at, updated_at) VALUES (?, ?, ?, datetime('now'), datetime('now'))",
user.ID, models.TierPro, "transactionid=txnabc123data",
).Error)
// Avoid underscores in search term: escapeLikeWildcards escapes _ with \
// which PostgreSQL handles but SQLite does not (no default ESCAPE clause)
found, err := repo.FindByAppleReceiptContains("txnabc123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestSubscriptionRepository_FindByAppleReceiptContains_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.FindByAppleReceiptContains("nonexistent_txn")
assert.Error(t, err)
}
func TestSubscriptionRepository_FindByGoogleToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
purchaseToken := "google_purchase_abc"
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
GooglePurchaseToken: &purchaseToken,
}
require.NoError(t, db.Create(sub).Error)
found, err := repo.FindByGoogleToken("google_purchase_abc")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestSubscriptionRepository_FindByGoogleToken_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.FindByGoogleToken("nonexistent_token")
assert.Error(t, err)
}
func TestSubscriptionRepository_SetCancelledAt(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
cancelTime := time.Now().UTC()
err = repo.SetCancelledAt(user.ID, cancelTime)
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
require.NotNil(t, found.CancelledAt)
}
func TestSubscriptionRepository_ClearCancelledAt(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
// Set then clear
err = repo.SetCancelledAt(user.ID, time.Now().UTC())
require.NoError(t, err)
err = repo.ClearCancelledAt(user.ID)
require.NoError(t, err)
found, err := repo.FindByUserID(user.ID)
require.NoError(t, err)
assert.Nil(t, found.CancelledAt)
}
func TestSubscriptionRepository_GetTierLimits_Defaults(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
// No tier limits in DB, should return defaults
freeLimits, err := repo.GetTierLimits(models.TierFree)
require.NoError(t, err)
assert.NotNil(t, freeLimits)
proLimits, err := repo.GetTierLimits(models.TierPro)
require.NoError(t, err)
assert.NotNil(t, proLimits)
}
func TestSubscriptionRepository_GetAllTierLimits(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
// No limits seeded = empty
limits, err := repo.GetAllTierLimits()
require.NoError(t, err)
assert.Empty(t, limits)
}
func TestSubscriptionRepository_GetSettings_Defaults(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
// No settings in DB, should return default (limitations disabled)
settings, err := repo.GetSettings()
require.NoError(t, err)
assert.NotNil(t, settings)
assert.False(t, settings.EnableLimitations)
}
func TestSubscriptionRepository_GetUpgradeTrigger(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
trigger := &models.UpgradeTrigger{
TriggerKey: "max_residences",
Title: "Upgrade for more homes",
Message: "Free tier supports 1 home",
IsActive: true,
}
require.NoError(t, db.Create(trigger).Error)
found, err := repo.GetUpgradeTrigger("max_residences")
require.NoError(t, err)
assert.Equal(t, "max_residences", found.TriggerKey)
}
func TestSubscriptionRepository_GetUpgradeTrigger_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.GetUpgradeTrigger("nonexistent_key")
assert.Error(t, err)
}
func TestSubscriptionRepository_GetAllUpgradeTriggers(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
trigger1 := &models.UpgradeTrigger{TriggerKey: "key1", Title: "T1", Message: "M1", IsActive: true}
trigger2 := &models.UpgradeTrigger{TriggerKey: "key2", Title: "T2", Message: "M2", IsActive: true}
require.NoError(t, db.Create(trigger1).Error)
require.NoError(t, db.Create(trigger2).Error)
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
require.NoError(t, db.Exec(
"INSERT INTO subscription_upgradetrigger (trigger_key, title, message, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, datetime('now'), datetime('now'))",
"key3", "T3", "M3", false,
).Error)
triggers, err := repo.GetAllUpgradeTriggers()
require.NoError(t, err)
assert.Len(t, triggers, 2) // Only active
}
func TestSubscriptionRepository_GetFeatureBenefits(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
b1 := &models.FeatureBenefit{FeatureName: "Benefit 1", FreeTierText: "1 home", ProTierText: "Unlimited", DisplayOrder: 1, IsActive: true}
require.NoError(t, db.Create(b1).Error)
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
require.NoError(t, db.Exec(
"INSERT INTO subscription_featurebenefit (feature_name, free_tier_text, pro_tier_text, display_order, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))",
"Benefit 2", "3 tasks", "Unlimited", 2, false,
).Error)
benefits, err := repo.GetFeatureBenefits()
require.NoError(t, err)
assert.Len(t, benefits, 1) // Only active
}
func TestSubscriptionRepository_GetActivePromotions(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
now := time.Now().UTC()
active := &models.Promotion{
PromotionID: "promo_active",
TargetTier: models.TierPro,
Title: "Active Promo",
Message: "Get Pro now!",
StartDate: now.AddDate(0, 0, -1),
EndDate: now.AddDate(0, 0, 1),
IsActive: true,
}
expired := &models.Promotion{
PromotionID: "promo_expired",
TargetTier: models.TierPro,
Title: "Expired Promo",
Message: "Too late!",
StartDate: now.AddDate(0, 0, -10),
EndDate: now.AddDate(0, 0, -5),
IsActive: true,
}
require.NoError(t, db.Create(active).Error)
require.NoError(t, db.Create(expired).Error)
promos, err := repo.GetActivePromotions(models.TierPro)
require.NoError(t, err)
assert.Len(t, promos, 1)
assert.Equal(t, "promo_active", promos[0].PromotionID)
}
func TestSubscriptionRepository_GetPromotionByID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
promo := &models.Promotion{
PromotionID: "promo_find",
TargetTier: models.TierPro,
Title: "Find Me",
Message: "Found!",
StartDate: time.Now().UTC(),
EndDate: time.Now().UTC().AddDate(0, 0, 30),
IsActive: true,
}
require.NoError(t, db.Create(promo).Error)
found, err := repo.GetPromotionByID("promo_find")
require.NoError(t, err)
assert.Equal(t, "Find Me", found.Title)
}
func TestSubscriptionRepository_GetPromotionByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.GetPromotionByID("nonexistent")
assert.Error(t, err)
}
func TestSubscriptionRepository_FindByStripeSubscriptionID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
subID := "sub_test_123"
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
StripeSubscriptionID: &subID,
}
require.NoError(t, db.Create(sub).Error)
found, err := repo.FindByStripeSubscriptionID("sub_test_123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestSubscriptionRepository_FindByStripeSubscriptionID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
_, err := repo.FindByStripeSubscriptionID("sub_nonexistent")
assert.Error(t, err)
}

View File

@@ -0,0 +1,516 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
// === UpdateTx / Version Conflict Tests ===
func TestTaskRepository_UpdateTx_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
// Simulate stale version by setting wrong version
task.Version = 999
task.Title = "Should Fail"
tx := db.Begin()
err := repo.UpdateTx(tx, task)
tx.Rollback()
assert.ErrorIs(t, err, ErrVersionConflict)
}
func TestTaskRepository_UpdateTx_Success(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
originalVersion := task.Version
task.Title = "Updated via Tx"
tx := db.Begin()
err := repo.UpdateTx(tx, task)
require.NoError(t, err)
tx.Commit()
assert.Equal(t, originalVersion+1, task.Version)
found, err := repo.FindByID(task.ID)
require.NoError(t, err)
assert.Equal(t, "Updated via Tx", found.Title)
}
func TestTaskRepository_Update_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
task.Version = 999 // stale version
task.Title = "Should Fail"
err := repo.Update(task)
assert.ErrorIs(t, err, ErrVersionConflict)
}
// === Version Conflict on State Operations ===
func TestTaskRepository_Cancel_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.Cancel(task.ID, 999) // wrong version
assert.ErrorIs(t, err, ErrVersionConflict)
}
func TestTaskRepository_Uncancel_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.Uncancel(task.ID, 999)
assert.ErrorIs(t, err, ErrVersionConflict)
}
func TestTaskRepository_Archive_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.Archive(task.ID, 999)
assert.ErrorIs(t, err, ErrVersionConflict)
}
func TestTaskRepository_Unarchive_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.Unarchive(task.ID, 999)
assert.ErrorIs(t, err, ErrVersionConflict)
}
func TestTaskRepository_MarkInProgress(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.MarkInProgress(task.ID, task.Version)
require.NoError(t, err)
found, err := repo.FindByID(task.ID)
require.NoError(t, err)
assert.True(t, found.InProgress)
}
func TestTaskRepository_MarkInProgress_VersionConflict(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.MarkInProgress(task.ID, 999)
assert.ErrorIs(t, err, ErrVersionConflict)
}
// === CreateCompletionTx ===
func TestTaskRepository_CreateCompletionTx(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
tx := db.Begin()
completion := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
Notes: "Done via tx",
}
err := repo.CreateCompletionTx(tx, completion)
require.NoError(t, err)
tx.Commit()
assert.NotZero(t, completion.ID)
found, err := repo.FindCompletionByID(completion.ID)
require.NoError(t, err)
assert.Equal(t, "Done via tx", found.Notes)
}
// === GetFrequencyByID ===
func TestTaskRepository_GetFrequencyByID(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskRepository(db)
frequencies, err := repo.GetAllFrequencies()
require.NoError(t, err)
require.NotEmpty(t, frequencies)
found, err := repo.GetFrequencyByID(frequencies[0].ID)
require.NoError(t, err)
assert.Equal(t, frequencies[0].Name, found.Name)
}
func TestTaskRepository_GetFrequencyByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
_, err := repo.GetFrequencyByID(9999)
assert.Error(t, err)
}
// === FindByUser ===
func TestTaskRepository_FindByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task A")
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task B")
tasks, err := repo.FindByUser(user.ID, []uint{r1.ID, r2.ID})
require.NoError(t, err)
assert.Len(t, tasks, 2)
}
// === CountByResidenceIDs ===
func TestTaskRepository_CountByResidenceIDs(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 3")
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestTaskRepository_CountByResidenceIDs_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
count, err := repo.CountByResidenceIDs([]uint{})
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
// === FindCompletionsByUser ===
func TestTaskRepository_FindCompletionsByUser(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
c := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
}
require.NoError(t, db.Create(c).Error)
completions, err := repo.FindCompletionsByUser(user.ID, []uint{residence.ID})
require.NoError(t, err)
assert.Len(t, completions, 1)
}
// === UpdateCompletion ===
func TestTaskRepository_UpdateCompletion(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
completion := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
Notes: "Original",
}
require.NoError(t, db.Create(completion).Error)
completion.Notes = "Updated notes"
err := repo.UpdateCompletion(completion)
require.NoError(t, err)
found, err := repo.FindCompletionByID(completion.ID)
require.NoError(t, err)
assert.Equal(t, "Updated notes", found.Notes)
}
// === CompletionImage CRUD ===
func TestTaskRepository_CompletionImage_CRUD(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
completion := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
}
require.NoError(t, db.Create(completion).Error)
// Create image
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: "https://example.com/img.jpg",
}
err := repo.CreateCompletionImage(img)
require.NoError(t, err)
assert.NotZero(t, img.ID)
// Find image
found, err := repo.FindCompletionImageByID(img.ID)
require.NoError(t, err)
assert.Equal(t, "https://example.com/img.jpg", found.ImageURL)
// Delete image
err = repo.DeleteCompletionImage(img.ID)
require.NoError(t, err)
_, err = repo.FindCompletionImageByID(img.ID)
assert.Error(t, err)
}
func TestTaskRepository_FindCompletionImageByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
_, err := repo.FindCompletionImageByID(9999)
assert.Error(t, err)
}
// === GetOverdueCountByResidence ===
func TestTaskRepository_GetOverdueCountByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
now := time.Now().UTC()
pastDue := now.AddDate(0, 0, -5)
// Overdue task in r1
t1 := &models.Task{
ResidenceID: r1.ID,
CreatedByID: user.ID,
Title: "Overdue in r1",
DueDate: &pastDue,
Version: 1,
}
require.NoError(t, db.Create(t1).Error)
// Not overdue task in r2 (future)
futureDue := now.AddDate(0, 0, 30)
t2 := &models.Task{
ResidenceID: r2.ID,
CreatedByID: user.ID,
Title: "Future in r2",
DueDate: &futureDue,
Version: 1,
}
require.NoError(t, db.Create(t2).Error)
countMap, err := repo.GetOverdueCountByResidence([]uint{r1.ID, r2.ID}, now)
require.NoError(t, err)
assert.Equal(t, 1, countMap[r1.ID])
assert.Equal(t, 0, countMap[r2.ID])
}
func TestTaskRepository_GetOverdueCountByResidence_EmptyIDs(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
countMap, err := repo.GetOverdueCountByResidence([]uint{}, time.Now().UTC())
require.NoError(t, err)
assert.Empty(t, countMap)
}
// === GetKanbanDataForMultipleResidences ===
func TestTaskRepository_GetKanbanDataForMultipleResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
board, err := repo.GetKanbanDataForMultipleResidences([]uint{r1.ID, r2.ID}, 30, time.Now().UTC())
require.NoError(t, err)
assert.Equal(t, "all", board.ResidenceID)
assert.Len(t, board.Columns, 5)
// Both tasks should appear in upcoming (no due date)
var upcomingCol *models.KanbanColumn
for i := range board.Columns {
if board.Columns[i].Name == "upcoming_tasks" {
upcomingCol = &board.Columns[i]
}
}
require.NotNil(t, upcomingCol)
assert.Equal(t, 2, upcomingCol.Count)
}
// === DB() accessor ===
func TestTaskRepository_DB(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
assert.NotNil(t, repo.DB())
}
// === GetBatchCompletionSummaries ===
func TestTaskRepository_GetBatchCompletionSummaries_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
result, err := repo.GetBatchCompletionSummaries([]uint{}, time.Now().UTC(), 10)
require.NoError(t, err)
assert.Empty(t, result)
}
func TestTaskRepository_GetBatchCompletionSummaries_WithData(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
task1 := testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
task2 := testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
now := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
c1 := models.TaskCompletion{
TaskID: task1.ID, CompletedByID: user.ID,
CompletedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), CompletedFromColumn: "completed_tasks",
}
require.NoError(t, db.Create(&c1).Error)
c2 := models.TaskCompletion{
TaskID: task2.ID, CompletedByID: user.ID,
CompletedAt: time.Date(2026, 1, 15, 12, 0, 0, 0, time.UTC), CompletedFromColumn: "overdue_tasks",
}
require.NoError(t, db.Create(&c2).Error)
result, err := repo.GetBatchCompletionSummaries([]uint{r1.ID, r2.ID}, now, 10)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, 1, result[r1.ID].TotalAllTime)
assert.Equal(t, 1, result[r2.ID].TotalAllTime)
}
// === FindCompletionByID not found ===
func TestTaskRepository_FindCompletionByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
_, err := repo.FindCompletionByID(9999)
assert.Error(t, err)
}
// === DeleteCompletion deletes images ===
func TestTaskRepository_DeleteCompletion_DeletesImages(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
completion := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
}
require.NoError(t, db.Create(completion).Error)
// Add images
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: "https://example.com/img.jpg",
}
require.NoError(t, db.Create(img).Error)
// Delete completion (should cascade to images)
err := repo.DeleteCompletion(completion.ID)
require.NoError(t, err)
// Images should be gone
var count int64
db.Model(&models.TaskCompletionImage{}).Where("completion_id = ?", completion.ID).Count(&count)
assert.Equal(t, int64(0), count)
}

View File

@@ -0,0 +1,236 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestTaskTemplateRepository_Create(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskTemplateRepository(db)
var cat models.TaskCategory
require.NoError(t, db.First(&cat).Error)
var freq models.TaskFrequency
require.NoError(t, db.First(&freq).Error)
template := &models.TaskTemplate{
Title: "Change HVAC Filter",
Description: "Replace the HVAC air filter",
CategoryID: &cat.ID,
FrequencyID: &freq.ID,
IsActive: true,
Tags: "hvac,filter,maintenance",
}
err := repo.Create(template)
require.NoError(t, err)
assert.NotZero(t, template.ID)
}
func TestTaskTemplateRepository_GetAll(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskTemplateRepository(db)
// Create active and inactive templates
t1 := &models.TaskTemplate{Title: "Active Template", IsActive: true}
t2 := &models.TaskTemplate{Title: "Inactive Template", IsActive: false}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
templates, err := repo.GetAll()
require.NoError(t, err)
assert.Len(t, templates, 1) // Only active
assert.Equal(t, "Active Template", templates[0].Title)
}
func TestTaskTemplateRepository_GetByCategory(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskTemplateRepository(db)
var cat models.TaskCategory
require.NoError(t, db.First(&cat).Error)
t1 := &models.TaskTemplate{Title: "Cat Template", CategoryID: &cat.ID, IsActive: true}
t2 := &models.TaskTemplate{Title: "No Cat Template", IsActive: true}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
templates, err := repo.GetByCategory(cat.ID)
require.NoError(t, err)
assert.Len(t, templates, 1)
assert.Equal(t, "Cat Template", templates[0].Title)
}
func TestTaskTemplateRepository_Search(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
t1 := &models.TaskTemplate{Title: "Change HVAC Filter", Tags: "hvac,filter", IsActive: true}
t2 := &models.TaskTemplate{Title: "Fix Faucet", Tags: "plumbing", IsActive: true}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
// Search by title
results, err := repo.Search("hvac")
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "Change HVAC Filter", results[0].Title)
// Search by tag
results, err = repo.Search("plumbing")
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "Fix Faucet", results[0].Title)
// No match
results, err = repo.Search("nonexistent")
require.NoError(t, err)
assert.Empty(t, results)
}
func TestTaskTemplateRepository_GetByID(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
tmpl := &models.TaskTemplate{Title: "Test Template", IsActive: true}
require.NoError(t, db.Create(tmpl).Error)
found, err := repo.GetByID(tmpl.ID)
require.NoError(t, err)
assert.Equal(t, "Test Template", found.Title)
}
func TestTaskTemplateRepository_GetByID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
_, err := repo.GetByID(9999)
assert.Error(t, err)
}
func TestTaskTemplateRepository_Update(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
tmpl := &models.TaskTemplate{Title: "Original", IsActive: true}
require.NoError(t, db.Create(tmpl).Error)
tmpl.Title = "Updated"
err := repo.Update(tmpl)
require.NoError(t, err)
found, err := repo.GetByID(tmpl.ID)
require.NoError(t, err)
assert.Equal(t, "Updated", found.Title)
}
func TestTaskTemplateRepository_Delete(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
tmpl := &models.TaskTemplate{Title: "To Delete", IsActive: true}
require.NoError(t, db.Create(tmpl).Error)
err := repo.Delete(tmpl.ID)
require.NoError(t, err)
_, err = repo.GetByID(tmpl.ID)
assert.Error(t, err)
}
func TestTaskTemplateRepository_GetAllIncludingInactive(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
t1 := &models.TaskTemplate{Title: "Active", IsActive: true}
t2 := &models.TaskTemplate{Title: "Inactive", IsActive: false}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
templates, err := repo.GetAllIncludingInactive()
require.NoError(t, err)
assert.Len(t, templates, 2) // Both active and inactive
}
func TestTaskTemplateRepository_Count(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
t1 := &models.TaskTemplate{Title: "Active 1", IsActive: true}
t2 := &models.TaskTemplate{Title: "Active 2", IsActive: true}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
require.NoError(t, db.Exec(
"INSERT INTO task_tasktemplate (title, is_active, display_order, created_at, updated_at) VALUES (?, ?, ?, datetime('now'), datetime('now'))",
"Inactive", false, 0,
).Error)
count, err := repo.Count()
require.NoError(t, err)
assert.Equal(t, int64(2), count) // Only active
}
func TestTaskTemplateRepository_GetByRegion(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
repo := NewTaskTemplateRepository(db)
// Create a climate region
region := &models.ClimateRegion{Name: "Hot-Humid", ZoneNumber: 1, IsActive: true}
require.NoError(t, db.Create(region).Error)
// Create template with region association
tmpl := &models.TaskTemplate{Title: "Regional Task", IsActive: true}
require.NoError(t, db.Create(tmpl).Error)
// Associate template with region via join table
err := db.Exec("INSERT INTO task_tasktemplate_regions (task_template_id, climate_region_id) VALUES (?, ?)", tmpl.ID, region.ID).Error
require.NoError(t, err)
// Create template without region
tmpl2 := &models.TaskTemplate{Title: "Non-Regional Task", IsActive: true}
require.NoError(t, db.Create(tmpl2).Error)
templates, err := repo.GetByRegion(region.ID)
require.NoError(t, err)
assert.Len(t, templates, 1)
assert.Equal(t, "Regional Task", templates[0].Title)
}
func TestTaskTemplateRepository_GetGroupedByCategory(t *testing.T) {
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
repo := NewTaskTemplateRepository(db)
var cat models.TaskCategory
require.NoError(t, db.First(&cat).Error)
t1 := &models.TaskTemplate{Title: "Categorized", CategoryID: &cat.ID, IsActive: true}
t2 := &models.TaskTemplate{Title: "Uncategorized", IsActive: true}
require.NoError(t, db.Create(t1).Error)
require.NoError(t, db.Create(t2).Error)
grouped, err := repo.GetGroupedByCategory()
require.NoError(t, err)
assert.GreaterOrEqual(t, len(grouped), 2) // At least the category + "Uncategorized"
// Uncategorized should have the template without category
assert.Len(t, grouped["Uncategorized"], 1)
assert.Equal(t, "Uncategorized", grouped["Uncategorized"][0].Title)
}

View File

@@ -0,0 +1,465 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestUserRepository_FindByUsername_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "TestUser", "test@example.com", "Password123")
// Should find with different cases
found, err := repo.FindByUsername("testuser")
require.NoError(t, err)
assert.Equal(t, "TestUser", found.Username)
found, err = repo.FindByUsername("TESTUSER")
require.NoError(t, err)
assert.Equal(t, "TestUser", found.Username)
}
func TestUserRepository_FindByUsername_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByUsername("nonexistent")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotFound)
}
func TestUserRepository_FindByEmail_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "testuser", "Test@Example.com", "Password123")
found, err := repo.FindByEmail("test@example.com")
require.NoError(t, err)
assert.Equal(t, "Test@Example.com", found.Email)
}
func TestUserRepository_FindByEmail_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByEmail("nonexistent@example.com")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotFound)
}
func TestUserRepository_ExistsByUsername_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "TestUser", "test@example.com", "Password123")
exists, err := repo.ExistsByUsername("TESTUSER")
require.NoError(t, err)
assert.True(t, exists)
}
func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "testuser", "Test@Example.com", "Password123")
exists, err := repo.ExistsByEmail("test@example.com")
require.NoError(t, err)
assert.True(t, exists)
}
func TestUserRepository_GetOrCreateToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create token
token1, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
assert.NotEmpty(t, token1.Key)
// Should return same token
token2, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
assert.Equal(t, token1.Key, token2.Key)
}
func TestUserRepository_FindTokenByKey(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
found, err := repo.FindTokenByKey(token.Key)
require.NoError(t, err)
assert.Equal(t, token.Key, found.Key)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindTokenByKey("nonexistent-token-key")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrTokenNotFound)
}
func TestUserRepository_DeleteToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
err = repo.DeleteToken(token.Key)
require.NoError(t, err)
_, err = repo.FindTokenByKey(token.Key)
assert.ErrorIs(t, err, ErrTokenNotFound)
}
func TestUserRepository_DeleteToken_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
err := repo.DeleteToken("nonexistent-key")
assert.ErrorIs(t, err, ErrTokenNotFound)
}
func TestUserRepository_DeleteTokenByUserID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
err = repo.DeleteTokenByUserID(user.ID)
require.NoError(t, err)
// Token should be gone
var count int64
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
}
func TestUserRepository_CreateToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
token, err := repo.CreateToken(user.ID)
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Equal(t, user.ID, token.UserID)
}
func TestUserRepository_UpdateLastLogin(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
err := repo.UpdateLastLogin(user.ID)
require.NoError(t, err)
found, err := repo.FindByID(user.ID)
require.NoError(t, err)
assert.NotNil(t, found.LastLogin)
}
func TestUserRepository_UpdateProfile(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
profile, err := repo.GetOrCreateProfile(user.ID)
require.NoError(t, err)
profile.Bio = "Test bio"
profile.PhoneNumber = "+1-555-0100"
err = repo.UpdateProfile(profile)
require.NoError(t, err)
updated, err := repo.GetOrCreateProfile(user.ID)
require.NoError(t, err)
assert.Equal(t, "Test bio", updated.Bio)
assert.Equal(t, "+1-555-0100", updated.PhoneNumber)
}
func TestUserRepository_SetProfileVerified(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create profile
_, err := repo.GetOrCreateProfile(user.ID)
require.NoError(t, err)
err = repo.SetProfileVerified(user.ID, true)
require.NoError(t, err)
profile, err := repo.GetOrCreateProfile(user.ID)
require.NoError(t, err)
assert.True(t, profile.Verified)
}
func TestUserRepository_FindByIDWithProfile(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create profile first
profile := &models.UserProfile{
UserID: user.ID,
Bio: "My bio",
}
err := db.Create(profile).Error
require.NoError(t, err)
found, err := repo.FindByIDWithProfile(user.ID)
require.NoError(t, err)
assert.Equal(t, user.ID, found.ID)
require.NotNil(t, found.Profile)
assert.Equal(t, "My bio", found.Profile.Bio)
}
func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByIDWithProfile(9999)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotFound)
}
func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create confirmation code
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt)
require.NoError(t, err)
assert.NotZero(t, code.ID)
// Find it
found, err := repo.FindConfirmationCode(user.ID, "123456")
require.NoError(t, err)
assert.Equal(t, code.ID, found.ID)
// Mark as used
err = repo.MarkConfirmationCodeUsed(code.ID)
require.NoError(t, err)
// Should not find used code
_, err = repo.FindConfirmationCode(user.ID, "123456")
assert.Error(t, err)
}
func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
// Create first code
code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt)
require.NoError(t, err)
// Create second code (should invalidate first)
_, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt)
require.NoError(t, err)
// First code should be used/invalidated
var c models.ConfirmationCode
db.First(&c, code1.ID)
assert.True(t, c.IsUsed)
}
func TestUserRepository_Transaction(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
err := repo.Transaction(func(txRepo *UserRepository) error {
found, err := txRepo.FindByID(user.ID)
if err != nil {
return err
}
found.FirstName = "Updated"
return txRepo.Update(found)
})
require.NoError(t, err)
found, err := repo.FindByID(user.ID)
require.NoError(t, err)
assert.Equal(t, "Updated", found.FirstName)
}
func TestUserRepository_DB(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
assert.NotNil(t, repo.DB())
}
func TestUserRepository_FindByAppleID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
appleAuth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_123",
Email: "apple@test.com",
}
require.NoError(t, db.Create(appleAuth).Error)
found, err := repo.FindByAppleID("apple_sub_123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindByAppleID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByAppleID("nonexistent_apple_id")
assert.ErrorIs(t, err, ErrAppleAuthNotFound)
}
func TestUserRepository_FindByGoogleID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
googleAuth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_123",
Email: "google@test.com",
}
require.NoError(t, db.Create(googleAuth).Error)
found, err := repo.FindByGoogleID("google_sub_123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByGoogleID("nonexistent_google_id")
assert.ErrorIs(t, err, ErrGoogleAuthNotFound)
}
func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
auth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_456",
Email: "apple@test.com",
}
err := repo.CreateAppleSocialAuth(auth)
require.NoError(t, err)
assert.NotZero(t, auth.ID)
auth.Email = "updated@test.com"
err = repo.UpdateAppleSocialAuth(auth)
require.NoError(t, err)
found, err := repo.FindByAppleID("apple_sub_456")
require.NoError(t, err)
assert.Equal(t, "updated@test.com", found.Email)
}
func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
auth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_456",
Email: "google@test.com",
Name: "Test User",
}
err := repo.CreateGoogleSocialAuth(auth)
require.NoError(t, err)
assert.NotZero(t, auth.ID)
auth.Name = "Updated Name"
err = repo.UpdateGoogleSocialAuth(auth)
require.NoError(t, err)
found, err := repo.FindByGoogleID("google_sub_456")
require.NoError(t, err)
assert.Equal(t, "Updated Name", found.Name)
}
func TestUserRepository_SearchUsers(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "john_doe", "john@example.com", "Password123")
testutil.CreateTestUser(t, db, "jane_smith", "jane@example.com", "Password123")
testutil.CreateTestUser(t, db, "bob_builder", "bob@example.com", "Password123")
users, total, err := repo.SearchUsers("john", 10, 0)
require.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Len(t, users, 1)
assert.Equal(t, "john_doe", users[0].Username)
}
func TestUserRepository_ListUsers(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "user1", "user1@example.com", "Password123")
testutil.CreateTestUser(t, db, "user2", "user2@example.com", "Password123")
testutil.CreateTestUser(t, db, "user3", "user3@example.com", "Password123")
users, total, err := repo.ListUsers(2, 0)
require.NoError(t, err)
assert.Equal(t, int64(3), total)
assert.Len(t, users, 2) // Limited to 2
users, total, err = repo.ListUsers(2, 2)
require.NoError(t, err)
assert.Equal(t, int64(3), total)
assert.Len(t, users, 1) // Only 1 remaining
}

View File

@@ -0,0 +1,367 @@
package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
// === Password Reset Code Lifecycle ===
func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt)
require.NoError(t, err)
assert.NotZero(t, code.ID)
assert.Equal(t, "hash_abc123", code.CodeHash)
assert.Equal(t, "reset_token_xyz", code.ResetToken)
assert.False(t, code.Used)
assert.Equal(t, 0, code.Attempts)
}
func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
require.NoError(t, err)
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
require.NoError(t, err)
// First code should be marked as used
var c models.PasswordResetCode
db.First(&c, code1.ID)
assert.True(t, c.Used)
}
func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt)
require.NoError(t, err)
found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com")
require.NoError(t, err)
assert.Equal(t, user.ID, foundUser.ID)
assert.Equal(t, "hash_abc", found.CodeHash)
}
func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com")
assert.Error(t, err)
}
func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, _, err := repo.FindPasswordResetCodeByEmail("test@example.com")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt)
require.NoError(t, err)
found, err := repo.FindPasswordResetCodeByToken("token_xyz")
require.NoError(t, err)
assert.Equal(t, "hash_xyz", found.CodeHash)
}
func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindPasswordResetCodeByToken("nonexistent_token")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Already expired
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt)
require.NoError(t, err)
_, err = repo.FindPasswordResetCodeByToken("token_exp")
assert.ErrorIs(t, err, ErrCodeExpired)
}
func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt)
require.NoError(t, err)
// Mark as used
err = repo.MarkPasswordResetCodeUsed(code.ID)
require.NoError(t, err)
_, err = repo.FindPasswordResetCodeByToken("token_used")
assert.ErrorIs(t, err, ErrCodeUsed)
}
func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt)
require.NoError(t, err)
// Max out attempts
for i := 0; i < 5; i++ {
err = repo.IncrementResetCodeAttempts(code.ID)
require.NoError(t, err)
}
_, err = repo.FindPasswordResetCodeByToken("token_attempts")
assert.ErrorIs(t, err, ErrTooManyAttempts)
}
func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt)
require.NoError(t, err)
err = repo.IncrementResetCodeAttempts(code.ID)
require.NoError(t, err)
var updated models.PasswordResetCode
db.First(&updated, code.ID)
assert.Equal(t, 1, updated.Attempts)
}
func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt)
require.NoError(t, err)
err = repo.MarkPasswordResetCodeUsed(code.ID)
require.NoError(t, err)
var updated models.PasswordResetCode
db.First(&updated, code.ID)
assert.True(t, updated.Used)
}
func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
require.NoError(t, err)
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
require.NoError(t, err)
count, err := repo.CountRecentPasswordResetRequests(user.ID)
require.NoError(t, err)
assert.Equal(t, int64(2), count)
}
// === FindUsersInSharedResidences ===
func TestUserRepository_FindUsersInSharedResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := NewUserRepository(db)
resRepo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
unrelated := testutil.CreateTestUser(t, db, "unrelated", "unrelated@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
resRepo.AddUser(residence.ID, shared.ID)
// Owner should see shared user
users, err := userRepo.FindUsersInSharedResidences(owner.ID)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, shared.ID, users[0].ID)
// Shared user should see owner
users, err = userRepo.FindUsersInSharedResidences(shared.ID)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, owner.ID, users[0].ID)
// Unrelated should see no one
users, err = userRepo.FindUsersInSharedResidences(unrelated.ID)
require.NoError(t, err)
assert.Empty(t, users)
}
// === FindUserIfSharedResidence ===
func TestUserRepository_FindUserIfSharedResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := NewUserRepository(db)
resRepo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
unrelated := testutil.CreateTestUser(t, db, "unrelated", "unrelated@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
resRepo.AddUser(residence.ID, shared.ID)
// Owner requesting shared user => should find
found, err := userRepo.FindUserIfSharedResidence(shared.ID, owner.ID)
require.NoError(t, err)
require.NotNil(t, found)
assert.Equal(t, shared.ID, found.ID)
// Unrelated requesting shared user => should not find
found, err = userRepo.FindUserIfSharedResidence(shared.ID, unrelated.ID)
require.NoError(t, err)
assert.Nil(t, found)
// Requesting self => should work
found, err = userRepo.FindUserIfSharedResidence(owner.ID, owner.ID)
require.NoError(t, err)
require.NotNil(t, found)
assert.Equal(t, owner.ID, found.ID)
}
// === FindProfilesInSharedResidences ===
func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := NewUserRepository(db)
resRepo := NewResidenceRepository(db)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
// Create profiles
ownerProfile := &models.UserProfile{UserID: owner.ID, Bio: "Owner bio"}
sharedProfile := &models.UserProfile{UserID: shared.ID, Bio: "Shared bio"}
require.NoError(t, db.Create(ownerProfile).Error)
require.NoError(t, db.Create(sharedProfile).Error)
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
resRepo.AddUser(residence.ID, shared.ID)
// Owner sees own profile + shared user profile
profiles, err := userRepo.FindProfilesInSharedResidences(owner.ID)
require.NoError(t, err)
assert.Len(t, profiles, 2)
}
// === ConfirmationCode Expired ===
func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create already-expired code
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
_, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt)
require.NoError(t, err)
_, err = repo.FindConfirmationCode(user.ID, "999999")
assert.ErrorIs(t, err, ErrCodeExpired)
}
func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, err := repo.FindConfirmationCode(user.ID, "000000")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
// === Transaction Rollback ===
func TestUserRepository_Transaction_Rollback(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
err := repo.Transaction(func(txRepo *UserRepository) error {
found, err := txRepo.FindByID(user.ID)
if err != nil {
return err
}
found.FirstName = "ShouldRollback"
if err := txRepo.Update(found); err != nil {
return err
}
// Simulate an error to trigger rollback
return ErrUserNotFound
})
assert.Error(t, err)
// Name should NOT have been updated
found, err := repo.FindByID(user.ID)
require.NoError(t, err)
assert.NotEqual(t, "ShouldRollback", found.FirstName)
}
// === FindByUsernameOrEmail not found ===
func TestUserRepository_FindByUsernameOrEmail_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByUsernameOrEmail("nonexistent")
assert.ErrorIs(t, err, ErrUserNotFound)
}

View File

@@ -19,7 +19,7 @@ func TestUserRepository_Create(t *testing.T) {
Email: "test@example.com",
IsActive: true,
}
user.SetPassword("password123")
user.SetPassword("Password123")
err := repo.Create(user)
require.NoError(t, err)
@@ -31,7 +31,7 @@ func TestUserRepository_FindByID(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Find by ID
found, err := repo.FindByID(user.ID)
@@ -54,7 +54,7 @@ func TestUserRepository_FindByUsername(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Find by username
found, err := repo.FindByUsername("testuser")
@@ -67,7 +67,7 @@ func TestUserRepository_FindByEmail(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Find by email
found, err := repo.FindByEmail("test@example.com")
@@ -80,7 +80,7 @@ func TestUserRepository_FindByUsernameOrEmail(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
tests := []struct {
name string
@@ -105,7 +105,7 @@ func TestUserRepository_Update(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Update user
user.FirstName = "John"
@@ -125,7 +125,7 @@ func TestUserRepository_ExistsByUsername(t *testing.T) {
repo := NewUserRepository(db)
// Create user
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "password123")
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "Password123")
tests := []struct {
name string
@@ -150,7 +150,7 @@ func TestUserRepository_ExistsByEmail(t *testing.T) {
repo := NewUserRepository(db)
// Create user
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "password123")
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "Password123")
tests := []struct {
name string
@@ -175,7 +175,7 @@ func TestUserRepository_GetOrCreateProfile(t *testing.T) {
repo := NewUserRepository(db)
// Create user
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// First call should create
profile1, err := repo.GetOrCreateProfile(user.ID)
@@ -193,14 +193,14 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
repo := NewUserRepository(db)
t.Run("email user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "password123")
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123")
provider, err := repo.FindAuthProvider(user.ID)
require.NoError(t, err)
assert.Equal(t, "email", provider)
})
t.Run("apple user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "password123")
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
appleAuth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_test",
@@ -214,7 +214,7 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
})
t.Run("google user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "password123")
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
googleAuth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_test",
@@ -233,7 +233,7 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "password123")
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123")
// Create profile and token
profile := &models.UserProfile{UserID: user.ID, Verified: true}
@@ -271,7 +271,7 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "password123")
user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test Home")
// Create document with file
@@ -319,8 +319,8 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "password123")
otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "password123")
owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "Password123")
otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "Password123")
otherResidence := testutil.CreateTestResidence(t, db, otherUser.ID, "Other Home")
// Owner's residence

View File

@@ -0,0 +1,29 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEscapeLikeWildcards(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"no wildcards", "hello", "hello"},
{"percent sign", "50% off", "50\\% off"},
{"underscore", "user_name", "user\\_name"},
{"both wildcards", "50%_off", "50\\%\\_off"},
{"empty string", "", ""},
{"only wildcards", "%_", "\\%\\_"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := escapeLikeWildcards(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -0,0 +1,262 @@
package router
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/labstack/echo/v4"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/services"
)
func TestMain(m *testing.M) {
// Initialize i18n so LocalizedMessage returns real translations
_ = i18n.Init()
os.Exit(m.Run())
}
// makeContext creates a fresh echo.Context and response recorder for testing.
func makeContext(method, path string) (echo.Context, *httptest.ResponseRecorder) {
e := echo.New()
req := httptest.NewRequest(method, path, nil)
req.Header.Set("Accept-Language", "en")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Set up i18n localizer on context (same key the middleware uses)
if i18n.Bundle != nil {
loc := i18n.NewLocalizer("en")
c.Set(i18n.LocalizerKey, loc)
}
return c, rec
}
// decodeError reads the JSON response into an ErrorResponse.
func decodeError(t *testing.T, rec *httptest.ResponseRecorder) responses.ErrorResponse {
t.Helper()
var resp responses.ErrorResponse
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode error response: %v\nbody: %s", err, rec.Body.String())
}
return resp
}
// --- AppError branch ---
func TestErrorHandler_AppError_NotFound(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/tasks/1")
err := apperrors.NotFound("error.task_not_found")
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
resp := decodeError(t, rec)
if resp.Error == "" {
t.Error("expected non-empty error message")
}
}
func TestErrorHandler_AppError_Forbidden(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/res/1")
err := apperrors.Forbidden("error.task_access_denied")
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusForbidden {
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestErrorHandler_AppError_BadRequest(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/tasks")
err := apperrors.BadRequest("error.task_already_cancelled")
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusBadRequest {
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestErrorHandler_AppError_WithWrappedErr(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/x")
err := apperrors.Internal(fmt.Errorf("db conn failed"))
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusInternalServerError {
t.Errorf("code = %d, want %d", rec.Code, http.StatusInternalServerError)
}
}
func TestErrorHandler_AppError_I18nFallback(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/x")
err := apperrors.NotFound("nonexistent.i18n.key").WithMessage("fallback msg")
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
resp := decodeError(t, rec)
if resp.Error != "fallback msg" {
t.Errorf("error = %q, want %q", resp.Error, "fallback msg")
}
}
// --- Echo HTTPError branch ---
func TestErrorHandler_EchoHTTPError_404(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/nope")
err := echo.NewHTTPError(http.StatusNotFound, "not found")
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
resp := decodeError(t, rec)
if resp.Error != "not found" {
t.Errorf("error = %q, want %q", resp.Error, "not found")
}
}
func TestErrorHandler_EchoHTTPError_405(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/x")
err := echo.ErrMethodNotAllowed
customHTTPErrorHandler(err, c)
if rec.Code != http.StatusMethodNotAllowed {
t.Errorf("code = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
}
}
// --- Sentinel error branch ---
func TestErrorHandler_ErrTaskNotFound(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/t")
customHTTPErrorHandler(services.ErrTaskNotFound, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestErrorHandler_ErrCompletionNotFound(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/t")
customHTTPErrorHandler(services.ErrCompletionNotFound, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestErrorHandler_ErrTaskAccessDenied(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/t")
customHTTPErrorHandler(services.ErrTaskAccessDenied, c)
if rec.Code != http.StatusForbidden {
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestErrorHandler_ErrTaskAlreadyCancelled(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/t")
customHTTPErrorHandler(services.ErrTaskAlreadyCancelled, c)
if rec.Code != http.StatusBadRequest {
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestErrorHandler_ErrTaskAlreadyArchived(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/t")
customHTTPErrorHandler(services.ErrTaskAlreadyArchived, c)
if rec.Code != http.StatusBadRequest {
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestErrorHandler_ErrResidenceNotFound(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/r")
customHTTPErrorHandler(services.ErrResidenceNotFound, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestErrorHandler_ErrResidenceAccessDenied(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/r")
customHTTPErrorHandler(services.ErrResidenceAccessDenied, c)
if rec.Code != http.StatusForbidden {
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestErrorHandler_ErrNotResidenceOwner(t *testing.T) {
c, rec := makeContext(http.MethodDelete, "/r")
customHTTPErrorHandler(services.ErrNotResidenceOwner, c)
if rec.Code != http.StatusForbidden {
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestErrorHandler_ErrPropertiesLimitReached(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/r")
customHTTPErrorHandler(services.ErrPropertiesLimitReached, c)
if rec.Code != http.StatusForbidden {
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestErrorHandler_ErrCannotRemoveOwner(t *testing.T) {
c, rec := makeContext(http.MethodDelete, "/r")
customHTTPErrorHandler(services.ErrCannotRemoveOwner, c)
if rec.Code != http.StatusBadRequest {
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestErrorHandler_ErrShareCodeExpired(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/r")
customHTTPErrorHandler(services.ErrShareCodeExpired, c)
if rec.Code != http.StatusBadRequest {
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestErrorHandler_ErrShareCodeInvalid(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/r")
customHTTPErrorHandler(services.ErrShareCodeInvalid, c)
if rec.Code != http.StatusNotFound {
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestErrorHandler_ErrUserAlreadyMember(t *testing.T) {
c, rec := makeContext(http.MethodPost, "/r")
customHTTPErrorHandler(services.ErrUserAlreadyMember, c)
if rec.Code != http.StatusConflict {
t.Errorf("code = %d, want %d", rec.Code, http.StatusConflict)
}
}
// --- Default branch ---
func TestErrorHandler_UnknownError(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/x")
customHTTPErrorHandler(errors.New("random unexpected error"), c)
if rec.Code != http.StatusInternalServerError {
t.Errorf("code = %d, want %d", rec.Code, http.StatusInternalServerError)
}
}
// --- Committed response ---
func TestErrorHandler_CommittedResponse_Noop(t *testing.T) {
c, rec := makeContext(http.MethodGet, "/x")
// Write something to commit the response
c.JSON(http.StatusOK, map[string]string{"ok": "true"})
// Capture the body after first write
bodyBefore := rec.Body.String()
// Now call error handler — it should be a no-op
customHTTPErrorHandler(errors.New("should be ignored"), c)
bodyAfter := rec.Body.String()
if bodyBefore != bodyAfter {
t.Errorf("body changed after committed response:\nbefore: %s\nafter: %s", bodyBefore, bodyAfter)
}
}

View File

@@ -0,0 +1,115 @@
package router
import (
"fmt"
"strings"
"github.com/treytartt/honeydue-api/internal/monitoring"
)
// CorsOrigins returns the CORS allowed origins based on debug mode.
func CorsOrigins(debug bool, configuredOrigins []string) []string {
if debug {
return []string{
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:8080",
"http://localhost:8000",
"http://127.0.0.1:3000",
"http://127.0.0.1:3001",
"http://127.0.0.1:8080",
"http://127.0.0.1:8000",
}
}
if len(configuredOrigins) > 0 {
return configuredOrigins
}
return []string{
"https://api.myhoneydue.com",
"https://myhoneydue.com",
"https://admin.myhoneydue.com",
}
}
// ShouldSkipTimeout returns true for paths that should bypass timeout middleware.
func ShouldSkipTimeout(path, host, adminHost string) bool {
return (adminHost != "" && host == adminHost) ||
strings.HasPrefix(path, "/_next") ||
strings.HasSuffix(path, "/ws")
}
// ShouldSkipBodyLimit returns true for webhook endpoints.
func ShouldSkipBodyLimit(path string) bool {
return strings.HasPrefix(path, "/api/subscription/webhook")
}
// ShouldSkipGzip returns true for media endpoints.
func ShouldSkipGzip(path string) bool {
return strings.HasPrefix(path, "/api/media/")
}
// ParseEndpoint splits "GET /api/foo" into method and path.
func ParseEndpoint(endpoint string) (method, path string) {
parts := strings.SplitN(endpoint, " ", 2)
if len(parts) == 2 {
return parts[0], parts[1]
}
return endpoint, ""
}
// AllowedProxyHosts builds the list of hosts allowed for admin proxy.
func AllowedProxyHosts(adminHost string) []string {
var hosts []string
if adminHost != "" {
hosts = append(hosts, adminHost)
}
hosts = append(hosts, "localhost:3001", "127.0.0.1:3001", "localhost:8000", "127.0.0.1:8000")
return hosts
}
// DetermineAdminRoute decides how to route admin subdomain requests.
func DetermineAdminRoute(path string) string {
if strings.HasPrefix(path, "/admin") {
return "redirect"
}
if strings.HasPrefix(path, "/api/") {
return "passthrough"
}
return "proxy"
}
// FormatPrometheusMetrics converts HTTP stats to Prometheus text format.
func FormatPrometheusMetrics(stats monitoring.HTTPStats) string {
var b strings.Builder
b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n")
b.WriteString("# TYPE http_requests_total counter\n")
for statusCode, count := range stats.ByStatusCode {
fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count)
}
b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n")
b.WriteString("# TYPE http_endpoint_requests_total counter\n")
for endpoint, epStats := range stats.ByEndpoint {
method, path := ParseEndpoint(endpoint)
fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count)
}
b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n")
b.WriteString("# TYPE http_request_duration_ms gauge\n")
for endpoint, epStats := range stats.ByEndpoint {
method, path := ParseEndpoint(endpoint)
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs)
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs)
}
b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n")
b.WriteString("# TYPE http_error_rate gauge\n")
fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate)
b.WriteString("# HELP http_requests_per_minute Current request rate.\n")
b.WriteString("# TYPE http_requests_per_minute gauge\n")
fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute)
return b.String()
}

View File

@@ -0,0 +1,200 @@
package router
import (
"strings"
"testing"
"github.com/treytartt/honeydue-api/internal/monitoring"
)
// --- CorsOrigins ---
func TestCorsOrigins_Debug(t *testing.T) {
origins := CorsOrigins(true, nil)
if len(origins) != 8 {
t.Errorf("len = %d, want 8", len(origins))
}
found := false
for _, o := range origins {
if o == "http://localhost:3000" {
found = true
}
}
if !found {
t.Error("expected localhost:3000 in debug origins")
}
}
func TestCorsOrigins_ProductionConfigured(t *testing.T) {
custom := []string{"https://example.com"}
origins := CorsOrigins(false, custom)
if len(origins) != 1 || origins[0] != "https://example.com" {
t.Errorf("got %v, want [https://example.com]", origins)
}
}
func TestCorsOrigins_ProductionDefault(t *testing.T) {
origins := CorsOrigins(false, nil)
if len(origins) != 3 {
t.Errorf("len = %d, want 3", len(origins))
}
found := false
for _, o := range origins {
if o == "https://myhoneydue.com" {
found = true
}
}
if !found {
t.Error("expected myhoneydue.com in default origins")
}
}
// --- ShouldSkipTimeout ---
func TestShouldSkipTimeout_AdminHost_True(t *testing.T) {
if !ShouldSkipTimeout("/some/path", "admin.example.com", "admin.example.com") {
t.Error("expected true for admin host")
}
}
func TestShouldSkipTimeout_NextPath_True(t *testing.T) {
if !ShouldSkipTimeout("/_next/static/chunk.js", "app.example.com", "admin.example.com") {
t.Error("expected true for _next path")
}
}
func TestShouldSkipTimeout_WsPath_True(t *testing.T) {
if !ShouldSkipTimeout("/api/events/ws", "app.example.com", "admin.example.com") {
t.Error("expected true for /ws path")
}
}
func TestShouldSkipTimeout_NormalPath_False(t *testing.T) {
if ShouldSkipTimeout("/api/tasks/", "app.example.com", "admin.example.com") {
t.Error("expected false for normal path")
}
}
func TestShouldSkipTimeout_EmptyAdminHost_False(t *testing.T) {
if ShouldSkipTimeout("/api/tasks/", "admin.example.com", "") {
t.Error("expected false when admin host is empty")
}
}
// --- ShouldSkipBodyLimit ---
func TestShouldSkipBodyLimit_Webhook_True(t *testing.T) {
if !ShouldSkipBodyLimit("/api/subscription/webhook/apple/") {
t.Error("expected true for webhook path")
}
}
func TestShouldSkipBodyLimit_Normal_False(t *testing.T) {
if ShouldSkipBodyLimit("/api/tasks/") {
t.Error("expected false for normal path")
}
}
// --- ShouldSkipGzip ---
func TestShouldSkipGzip_Media_True(t *testing.T) {
if !ShouldSkipGzip("/api/media/document/123") {
t.Error("expected true for media path")
}
}
func TestShouldSkipGzip_Api_False(t *testing.T) {
if ShouldSkipGzip("/api/tasks/") {
t.Error("expected false for non-media path")
}
}
// --- ParseEndpoint ---
func TestParseEndpoint_MethodAndPath(t *testing.T) {
method, path := ParseEndpoint("GET /api/tasks/")
if method != "GET" || path != "/api/tasks/" {
t.Errorf("got (%q, %q), want (GET, /api/tasks/)", method, path)
}
}
func TestParseEndpoint_NoSpace(t *testing.T) {
method, path := ParseEndpoint("GET")
if method != "GET" || path != "" {
t.Errorf("got (%q, %q), want (GET, \"\")", method, path)
}
}
// --- AllowedProxyHosts ---
func TestAllowedProxyHosts_WithAdmin(t *testing.T) {
hosts := AllowedProxyHosts("admin.example.com")
if len(hosts) != 5 {
t.Errorf("len = %d, want 5", len(hosts))
}
if hosts[0] != "admin.example.com" {
t.Errorf("first host = %q, want admin.example.com", hosts[0])
}
}
func TestAllowedProxyHosts_WithoutAdmin(t *testing.T) {
hosts := AllowedProxyHosts("")
if len(hosts) != 4 {
t.Errorf("len = %d, want 4", len(hosts))
}
}
// --- DetermineAdminRoute ---
func TestDetermineAdminRoute_Admin_Redirect(t *testing.T) {
got := DetermineAdminRoute("/admin/dashboard")
if got != "redirect" {
t.Errorf("got %q, want redirect", got)
}
}
func TestDetermineAdminRoute_Api_Passthrough(t *testing.T) {
got := DetermineAdminRoute("/api/tasks/")
if got != "passthrough" {
t.Errorf("got %q, want passthrough", got)
}
}
func TestDetermineAdminRoute_Other_Proxy(t *testing.T) {
got := DetermineAdminRoute("/dashboard")
if got != "proxy" {
t.Errorf("got %q, want proxy", got)
}
}
// --- FormatPrometheusMetrics ---
func TestFormatPrometheusMetrics_Output(t *testing.T) {
stats := monitoring.HTTPStats{
RequestsTotal: 100,
RequestsPerMinute: 10.5,
ErrorRate: 0.05,
ByStatusCode: map[int]int64{200: 90, 500: 10},
ByEndpoint: map[string]monitoring.EndpointStats{
"GET /api/tasks/": {Count: 50, AvgLatencyMs: 12.5, P95LatencyMs: 45.0},
},
}
output := FormatPrometheusMetrics(stats)
// Check key sections exist
checks := []string{
"http_requests_total",
"http_endpoint_requests_total",
"http_request_duration_ms",
"http_error_rate",
"http_requests_per_minute",
`method="GET"`,
`path="/api/tasks/"`,
}
for _, check := range checks {
if !strings.Contains(output, check) {
t.Errorf("output missing %q", check)
}
}
}

View File

@@ -35,7 +35,7 @@ func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
Email: "refresh@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("password123"))
require.NoError(t, user.SetPassword("Password123"))
require.NoError(t, db.Create(user).Error)
return user
}

View File

@@ -0,0 +1,800 @@
package services
import (
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/dto/requests"
"github.com/treytartt/honeydue-api/internal/repositories"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{
DebugFixedCodes: true,
},
Security: config.SecurityConfig{
SecretKey: "test-secret",
ConfirmationExpiry: 24 * time.Hour,
PasswordResetExpiry: 15 * time.Minute,
MaxPasswordResetRate: 3,
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
service := NewAuthService(userRepo, cfg)
service.SetNotificationRepository(notifRepo)
return service, userRepo
}
// === Login ===
func TestAuthService_Login(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
resp, err := service.Login(req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username)
}
func TestAuthService_Login_ByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Email: "test@test.com",
Password: "Password123",
}
resp, err := service.Login(req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
}
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "WrongPassword1",
}
_, err := service.Login(req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_UserNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
req := &requests.LoginRequest{
Username: "nonexistent",
Password: "Password123",
}
_, err := service.Login(req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_InactiveUser(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
// Deactivate
user.IsActive = false
db.Save(user)
req := &requests.LoginRequest{
Username: "inactive",
Password: "Password123",
}
_, err := service.Login(req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
}
// === Register ===
func TestAuthService_Register(t *testing.T) {
service, _ := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
resp, code, err := service.Register(req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "newuser", resp.User.Username)
assert.Equal(t, "123456", code) // DebugFixedCodes=true
}
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "taken",
Email: "different@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
}
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "newuser",
Email: "taken@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
}
// === GetCurrentUser ===
func TestAuthService_GetCurrentUser(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
// Create profile
userRepo.GetOrCreateProfile(user.ID)
resp, err := service.GetCurrentUser(user.ID)
require.NoError(t, err)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@test.com", resp.Email)
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth
}
// === UpdateProfile ===
func TestAuthService_UpdateProfile(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
userRepo.GetOrCreateProfile(user.ID)
newFirst := "John"
newLast := "Doe"
req := &requests.UpdateProfileRequest{
FirstName: &newFirst,
LastName: &newLast,
}
resp, err := service.UpdateProfile(user.ID, req)
require.NoError(t, err)
assert.Equal(t, "John", resp.FirstName)
assert.Equal(t, "Doe", resp.LastName)
}
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
user2 := testutil.CreateTestUser(t, db, "user2", "user2@test.com", "Password123")
userRepo.GetOrCreateProfile(user2.ID)
takenEmail := "user1@test.com"
req := &requests.UpdateProfileRequest{
Email: &takenEmail,
}
_, err := service.UpdateProfile(user2.ID, req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
}
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
userRepo.GetOrCreateProfile(user.ID)
sameEmail := "test@test.com"
req := &requests.UpdateProfileRequest{
Email: &sameEmail,
}
// Same email should not trigger duplicate error
resp, err := service.UpdateProfile(user.ID, req)
require.NoError(t, err)
assert.Equal(t, "test@test.com", resp.Email)
}
// === VerifyEmail ===
func TestAuthService_VerifyEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user (creates confirmation code)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
require.NoError(t, err)
// Get the user ID
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Verify with the debug code
err = service.VerifyEmail(user.ID, "123456")
require.NoError(t, err)
// Verify again — should get already verified error
err = service.VerifyEmail(user.ID, "123456")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
// but a wrong code should use the normal path
err = service.VerifyEmail(user.ID, "000000")
assert.Error(t, err)
}
// === ResendVerificationCode ===
func TestAuthService_ResendVerificationCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
code, err := service.ResendVerificationCode(user.ID)
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
}
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
service, _ := setupAuthService(t)
// Register and verify
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
err = service.VerifyEmail(user.ID, "123456")
require.NoError(t, err)
_, err = service.ResendVerificationCode(user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
// === ForgotPassword ===
func TestAuthService_ForgotPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user first
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
code, user, err := service.ForgotPassword("test@test.com")
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
assert.NotNil(t, user)
assert.Equal(t, "test@test.com", user.Email)
}
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Should not reveal that email doesn't exist
code, user, err := service.ForgotPassword("nonexistent@test.com")
require.NoError(t, err)
assert.Empty(t, code)
assert.Nil(t, user)
}
// === ResetPassword ===
func TestAuthService_ResetPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
// Forgot password
_, _, err = service.ForgotPassword("test@test.com")
require.NoError(t, err)
// Verify reset code to get the token
resetToken, err := service.VerifyResetCode("test@test.com", "123456")
require.NoError(t, err)
assert.NotEmpty(t, resetToken)
// Reset password
err = service.ResetPassword(resetToken, "NewPassword123")
require.NoError(t, err)
// Login with new password
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "NewPassword123",
}
loginResp, err := service.Login(loginReq)
require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token)
}
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
service, _ := setupAuthService(t)
err := service.ResetPassword("invalid-token", "NewPassword123")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
}
// === Logout ===
func TestAuthService_Logout(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
// Login first
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
loginResp, err := service.Login(loginReq)
require.NoError(t, err)
// Logout
err = service.Logout(loginResp.Token)
require.NoError(t, err)
// Token should be deleted — refreshing should fail
_, err = service.RefreshToken(loginResp.Token, user.ID)
assert.Error(t, err)
}
// === DeleteAccount ===
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
password := "Password123"
_, err = service.DeleteAccount(user.ID, &password, nil)
require.NoError(t, err)
}
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
wrongPassword := "WrongPassword1"
_, err = service.DeleteAccount(user.ID, &wrongPassword, nil)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
_, err = service.DeleteAccount(user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t)
password := "Password123"
_, err := service.DeleteAccount(99999, &password, nil)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
}
// === Helper functions ===
func TestGenerateSixDigitCode(t *testing.T) {
code := generateSixDigitCode()
assert.Len(t, code, 6)
// Should be numeric
for _, c := range code {
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
}
}
func TestGenerateResetToken(t *testing.T) {
token := generateResetToken()
assert.NotEmpty(t, token)
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
}
func TestGetStringOrEmpty(t *testing.T) {
s := "hello"
assert.Equal(t, "hello", getStringOrEmpty(&s))
assert.Equal(t, "", getStringOrEmpty(nil))
}
func TestIsPrivateRelayEmail(t *testing.T) {
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
}
func TestGetEmailFromRequest(t *testing.T) {
email := "req@test.com"
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
empty := ""
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
}
// === getEmailOrDefault ===
func TestGetEmailOrDefault(t *testing.T) {
// Non-empty email returns itself
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
// Empty email returns a generated placeholder
result := getEmailOrDefault("")
assert.Contains(t, result, "@privaterelay.appleid.com")
assert.Contains(t, result, "apple_")
}
// === generateUniqueUsername ===
func TestGenerateUniqueUsername(t *testing.T) {
// Normal email generates username from email prefix
username := generateUniqueUsername("john@test.com", nil)
assert.Contains(t, username, "john_")
// Private relay email falls back to first name
firstName := "Jane"
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
assert.Contains(t, username, "jane_")
// Private relay email and no first name — fallback
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
assert.Contains(t, username, "user_")
// Empty email with first name
firstName2 := "Bob"
username = generateUniqueUsername("", &firstName2)
assert.Contains(t, username, "bob_")
// Empty email and no first name
username = generateUniqueUsername("", nil)
assert.Contains(t, username, "user_")
}
// === generateGoogleUsername ===
func TestGenerateGoogleUsername(t *testing.T) {
// Normal email
username := generateGoogleUsername("john@gmail.com", "John")
assert.Contains(t, username, "john_")
// Empty email falls back to first name
username = generateGoogleUsername("", "Alice")
assert.Contains(t, username, "alice_")
// Empty email and empty first name — fallback
username = generateGoogleUsername("", "")
assert.Contains(t, username, "google_")
}
// === Login with empty password ===
func TestAuthService_Login_EmptyPassword(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "",
}
_, err := service.Login(req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === ForgotPassword rate limiting ===
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
// Make max allowed reset requests (3 based on setup)
for i := 0; i < 3; i++ {
_, _, err := service.ForgotPassword("test@test.com")
require.NoError(t, err)
}
// The 4th should be rate limited
_, _, err = service.ForgotPassword("test@test.com")
assert.Error(t, err)
}
// === VerifyResetCode with wrong code ===
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
_, _, err = service.ForgotPassword("test@test.com")
require.NoError(t, err)
// Wrong code but with debug mode, "123456" works, "000000" should fail
_, err = service.VerifyResetCode("test@test.com", "000000")
assert.Error(t, err)
}
// === VerifyResetCode with nonexistent email ===
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
_, err := service.VerifyResetCode("nonexistent@test.com", "123456")
assert.Error(t, err)
}
// === UpdateProfile — change email to new email ===
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
userRepo.GetOrCreateProfile(user.ID)
newEmail := "newemail@test.com"
req := &requests.UpdateProfileRequest{
Email: &newEmail,
}
resp, err := service.UpdateProfile(user.ID, req)
require.NoError(t, err)
assert.Equal(t, "newemail@test.com", resp.Email)
}
// === DeleteAccount — empty password string ===
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
emptyPw := ""
_, err = service.DeleteAccount(user.ID, &emptyPw, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
// === SetNotificationRepository ===
func TestAuthService_SetNotificationRepository(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
assert.Nil(t, service.notificationRepo)
service.SetNotificationRepository(notifRepo)
assert.NotNil(t, service.notificationRepo)
}
// === Register creates profile and notification preferences ===
func TestAuthService_Register_CreatesProfile(t *testing.T) {
service, userRepo := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "profileuser",
Email: "profile@test.com",
Password: "Password123",
FirstName: "John",
LastName: "Doe",
}
resp, _, err := service.Register(req)
require.NoError(t, err)
assert.Equal(t, "profileuser", resp.User.Username)
// Profile should exist
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
require.NoError(t, err)
assert.NotNil(t, profile)
}

View File

@@ -4,9 +4,11 @@ import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/dto/requests"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
"github.com/treytartt/honeydue-api/internal/testutil"
)
@@ -20,6 +22,420 @@ func setupContractorService(t *testing.T) (*ContractorService, *repositories.Con
return service, contractorRepo, residenceRepo
}
// === CreateContractor ===
func TestContractorService_CreateContractor(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
req := &requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Bob's Plumbing",
Phone: "555-1234",
Email: "bob@plumbing.com",
}
resp, err := service.CreateContractor(req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Bob's Plumbing", resp.Name)
assert.Equal(t, "555-1234", resp.Phone)
assert.Equal(t, "bob@plumbing.com", resp.Email)
}
func TestContractorService_CreateContractor_Personal(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// No residence ID - personal contractor
req := &requests.CreateContractorRequest{
Name: "Personal Handyman",
}
resp, err := service.CreateContractor(req, user.ID)
require.NoError(t, err)
assert.Equal(t, "Personal Handyman", resp.Name)
}
func TestContractorService_CreateContractor_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
req := &requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Unauthorized Contractor",
}
_, err := service.CreateContractor(req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
func TestContractorService_CreateContractor_WithFavorite(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
isFav := true
req := &requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Fav Plumber",
IsFavorite: &isFav,
}
resp, err := service.CreateContractor(req, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsFavorite)
}
// === GetContractor ===
func TestContractorService_GetContractor(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractor(contractor.ID, user.ID)
require.NoError(t, err)
assert.Equal(t, contractor.ID, resp.ID)
assert.Equal(t, "Test Contractor", resp.Name)
}
func TestContractorService_GetContractor_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractor(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
func TestContractorService_GetContractor_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractor(contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
resp, err := service.GetContractor(contractor.ID, shared.ID)
require.NoError(t, err)
assert.Equal(t, "Shared Contractor", resp.Name)
}
// === ListContractors ===
func TestContractorService_ListContractors(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
resp, err := service.ListContractors(user.ID)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
// === DeleteContractor ===
func TestContractorService_DeleteContractor(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteContractor(contractor.ID, user.ID)
require.NoError(t, err)
// Should not be found after deletion
_, err = service.GetContractor(contractor.ID, user.ID)
assert.Error(t, err)
}
func TestContractorService_DeleteContractor_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteContractor(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
err := service.DeleteContractor(contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
// === ToggleFavorite ===
func TestContractorService_ToggleFavorite(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Initially not favorite
resp, err := service.GetContractor(contractor.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsFavorite)
// Toggle to favorite
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsFavorite)
// Toggle back
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsFavorite)
}
func TestContractorService_ToggleFavorite_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ToggleFavorite(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.ToggleFavorite(contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
// === ListContractorsByResidence ===
func TestContractorService_ListContractorsByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.ListContractorsByResidence(residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
// === GetContractorTasks ===
func TestContractorService_GetContractorTasks_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractorTasks(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractorTasks(contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
func TestContractorService_GetContractorTasks_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractorTasks(contractor.ID, user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
// === GetSpecialties ===
func TestContractorService_GetSpecialties(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
resp, err := service.GetSpecialties()
require.NoError(t, err)
// SeedLookupData creates 4 specialties
assert.Len(t, resp, 4)
}
// === UpdateContractor ===
func TestContractorService_UpdateContractor_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
newName := "Won't Work"
req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
newName := "Hacked"
req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(contractor.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
@@ -96,3 +512,171 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
require.NoError(t, err, "should allow removing residence association")
require.NotNil(t, resp)
}
// === UpdateContractor — partial update multiple fields ===
func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Original Name")
newName := "Updated Plumber"
newPhone := "555-9999"
newEmail := "new@plumber.com"
newCompany := "Best Plumbing"
newWebsite := "https://bestplumbing.com"
newNotes := "Great work"
newStreet := "456 Plumber Ave"
newCity := "Dallas"
newState := "TX"
newPostal := "75001"
rating := 5.0
isFav := true
req := &requests.UpdateContractorRequest{
Name: &newName,
Phone: &newPhone,
Email: &newEmail,
Company: &newCompany,
Website: &newWebsite,
Notes: &newNotes,
StreetAddress: &newStreet,
City: &newCity,
StateProvince: &newState,
PostalCode: &newPostal,
Rating: &rating,
IsFavorite: &isFav,
ResidenceID: &residence.ID,
}
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated Plumber", resp.Name)
assert.Equal(t, "555-9999", resp.Phone)
assert.Equal(t, "new@plumber.com", resp.Email)
assert.True(t, resp.IsFavorite)
}
// === UpdateContractor — with specialties ===
func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Get specialty IDs from seeded data
var specialties []models.ContractorSpecialty
err := db.Find(&specialties).Error
require.NoError(t, err)
require.NotEmpty(t, specialties)
specialtyIDs := []uint{specialties[0].ID, specialties[1].ID}
req := &requests.UpdateContractorRequest{
SpecialtyIDs: specialtyIDs,
ResidenceID: &residence.ID,
}
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
require.NoError(t, err)
assert.NotNil(t, resp)
}
// === CreateContractor — with specialties ===
func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
var specialties []models.ContractorSpecialty
err := db.Find(&specialties).Error
require.NoError(t, err)
req := &requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Specialized Plumber",
SpecialtyIDs: []uint{specialties[0].ID},
}
resp, err := service.CreateContractor(req, user.ID)
require.NoError(t, err)
assert.Equal(t, "Specialized Plumber", resp.Name)
}
// === ListContractors — empty result ===
func TestContractorService_ListContractors_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// No residence, no contractors
resp, err := service.ListContractors(user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
// === ListContractorsByResidence — empty result ===
func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
// === Personal contractor access — creator has access, others don't ===
func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
creator := testutil.CreateTestUser(t, db, "creator", "creator@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
// Create personal contractor (no residence)
req := &requests.CreateContractorRequest{
Name: "Personal Plumber",
}
resp, err := service.CreateContractor(req, creator.ID)
require.NoError(t, err)
// Creator can access
_, err = service.GetContractor(resp.ID, creator.ID)
require.NoError(t, err)
// Other user cannot
_, err = service.GetContractor(resp.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}

View File

@@ -0,0 +1,764 @@
package services
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/dto/requests"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func setupDocumentService(t *testing.T) (*DocumentService, *repositories.DocumentRepository, *repositories.ResidenceRepository) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
return service, documentRepo, residenceRepo
}
// === CreateDocument ===
func TestDocumentService_CreateDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
req := &requests.CreateDocumentRequest{
ResidenceID: residence.ID,
Title: "Furnace Manual",
Description: "Installation manual for the furnace",
DocumentType: models.DocumentTypeManual,
FileURL: "https://example.com/manual.pdf",
FileName: "manual.pdf",
}
resp, err := service.CreateDocument(req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Furnace Manual", resp.Title)
assert.Equal(t, models.DocumentTypeManual, resp.DocumentType)
}
func TestDocumentService_CreateDocument_DefaultType(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
req := &requests.CreateDocumentRequest{
ResidenceID: residence.ID,
Title: "Some Document",
// DocumentType not set — should default to "general"
}
resp, err := service.CreateDocument(req, user.ID)
require.NoError(t, err)
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
}
func TestDocumentService_CreateDocument_WithImages(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
req := &requests.CreateDocumentRequest{
ResidenceID: residence.ID,
Title: "Receipt with photos",
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
}
resp, err := service.CreateDocument(req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Receipt with photos", resp.Title)
}
func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
req := &requests.CreateDocumentRequest{
ResidenceID: residence.ID,
Title: "Unauthorized Doc",
}
_, err := service.CreateDocument(req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
// === GetDocument ===
func TestDocumentService_GetDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.GetDocument(doc.ID, user.ID)
require.NoError(t, err)
assert.Equal(t, doc.ID, resp.ID)
assert.Equal(t, "Test Doc", resp.Title)
}
func TestDocumentService_GetDocument_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetDocument(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_GetDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.GetDocument(doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === UpdateDocument ===
func TestDocumentService_UpdateDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
newTitle := "Updated Title"
newDesc := "Updated description"
req := &requests.UpdateDocumentRequest{
Title: &newTitle,
Description: &newDesc,
}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated Title", resp.Title)
assert.Equal(t, "Updated description", resp.Description)
}
func TestDocumentService_UpdateDocument_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
newTitle := "Won't Work"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
newTitle := "Hacked"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(doc.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "My Receipt")
newType := models.DocumentTypeWarranty
req := &requests.UpdateDocumentRequest{DocumentType: &newType}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
}
// === DeleteDocument ===
func TestDocumentService_DeleteDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteDocument(doc.ID, user.ID)
require.NoError(t, err)
// Should not be found after deletion
_, err = service.GetDocument(doc.ID, user.ID)
assert.Error(t, err)
}
func TestDocumentService_DeleteDocument_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDocument(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
err := service.DeleteDocument(doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === ListDocuments ===
func TestDocumentService_ListDocuments(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
resp, err := service.ListDocuments(user.ID, nil)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
func TestDocumentService_ListDocuments_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListDocuments(user.ID, nil)
require.NoError(t, err)
assert.Empty(t, resp)
}
func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
residence2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
testutil.CreateTestDocument(t, db, residence1.ID, user.ID, "Doc A")
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
resp, err := service.ListDocuments(user.ID, filter)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "Doc A", resp[0].Title)
}
func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Owner House")
// other has their own residence so they have at least one
testutil.CreateTestResidence(t, db, other.ID, "Other House")
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
_, err := service.ListDocuments(other.ID, filter)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
// === ListWarranties ===
func TestDocumentService_ListWarranties(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a warranty doc directly
warrantyDoc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "HVAC Warranty",
DocumentType: "warranty",
FileURL: "https://example.com/warranty.pdf",
}
err := db.Create(warrantyDoc).Error
require.NoError(t, err)
// Create a general doc
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
resp, err := service.ListWarranties(user.ID)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "HVAC Warranty", resp[0].Title)
}
func TestDocumentService_ListWarranties_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListWarranties(user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
// === DeactivateDocument ===
func TestDocumentService_DeactivateDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
resp, err := service.DeactivateDocument(doc.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsActive)
}
func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.DeactivateDocument(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.DeactivateDocument(doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === UploadDocumentImage ===
func TestDocumentService_UploadDocumentImage(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
require.NoError(t, err)
assert.NotNil(t, resp)
}
func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === DeleteDocumentImage ===
func TestDocumentService_DeleteDocumentImage(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
// Create an image
img := &models.DocumentImage{
DocumentID: doc.ID,
ImageURL: "https://example.com/photo.jpg",
}
err := db.Create(img).Error
require.NoError(t, err)
resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
}
func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
_, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
}
func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc1 := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
doc2 := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
// Create an image on doc1
img := &models.DocumentImage{
DocumentID: doc1.ID,
ImageURL: "https://example.com/photo.jpg",
}
err := db.Create(img).Error
require.NoError(t, err)
// Try to delete the image specifying doc2
_, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
}
func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
img := &models.DocumentImage{
DocumentID: doc.ID,
ImageURL: "https://example.com/photo.jpg",
}
err := db.Create(img).Error
require.NoError(t, err)
_, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === SharedUser access ===
func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
resp, err := service.GetDocument(doc.ID, shared.ID)
require.NoError(t, err)
assert.Equal(t, "Shared Doc", resp.Title)
}
// === ActivateDocument ===
func TestDocumentService_ActivateDocument(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
// Deactivate first
_, err := service.DeactivateDocument(doc.ID, user.ID)
require.NoError(t, err)
// Now activate
resp, err := service.ActivateDocument(doc.ID, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsActive)
}
func TestDocumentService_ActivateDocument_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ActivateDocument(9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.ActivateDocument(doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
// === CreateDocument — with empty image URL in array (should skip) ===
func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
req := &requests.CreateDocumentRequest{
ResidenceID: residence.ID,
Title: "Doc with empty images",
ImageURLs: []string{"", "https://example.com/img.jpg", ""},
}
resp, err := service.CreateDocument(req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
}
// === UpdateDocument — all optional fields ===
func TestDocumentService_UpdateDocument_AllFields(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original")
newTitle := "Updated"
newDesc := "New description"
newType := models.DocumentTypeWarranty
newFileURL := "https://example.com/new.pdf"
newFileName := "new.pdf"
newMimeType := "application/pdf"
newVendor := "HVAC Corp"
newSerial := "SN12345"
newModel := "Model X"
size := int64(1024)
req := &requests.UpdateDocumentRequest{
Title: &newTitle,
Description: &newDesc,
DocumentType: &newType,
FileURL: &newFileURL,
FileName: &newFileName,
FileSize: &size,
MimeType: &newMimeType,
Vendor: &newVendor,
SerialNumber: &newSerial,
ModelNumber: &newModel,
}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated", resp.Title)
assert.Equal(t, "New description", resp.Description)
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
}
// === ListDocuments — filter by document type ===
func TestDocumentService_ListDocuments_FilterByType(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a warranty doc
warrantyDoc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Warranty Doc",
DocumentType: models.DocumentTypeWarranty,
FileURL: "https://example.com/w.pdf",
}
err := db.Create(warrantyDoc).Error
require.NoError(t, err)
// Create a general doc
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
resp, err := service.ListDocuments(user.ID, filter)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "Warranty Doc", resp[0].Title)
}
// === Shared user can update/delete documents ===
func TestDocumentService_SharedUser_CanUpdate(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
newTitle := "Updated by shared user"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
resp, err := service.UpdateDocument(doc.ID, shared.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated by shared user", resp.Title)
}
func TestDocumentService_SharedUser_CanDelete(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewDocumentService(documentRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
err := service.DeleteDocument(doc.ID, shared.ID)
require.NoError(t, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -456,3 +456,631 @@ func TestCreateResidence_ProTier_AllowsMore(t *testing.T) {
func ptrTime(t time.Time) *time.Time {
return &t
}
// === GetMyResidences ===
func TestResidenceService_GetMyResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
testutil.CreateTestResidence(t, db, user.ID, "House 1")
testutil.CreateTestResidence(t, db, user.ID, "House 2")
resp, err := service.GetMyResidences(user.ID, time.Now())
require.NoError(t, err)
assert.Len(t, resp.Residences, 2)
}
func TestResidenceService_GetMyResidences_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.GetMyResidences(user.ID, time.Now())
require.NoError(t, err)
assert.Empty(t, resp.Residences)
}
// === GetSummary ===
func TestResidenceService_GetSummary(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
testutil.CreateTestResidence(t, db, user.ID, "House 1")
testutil.CreateTestResidence(t, db, user.ID, "House 2")
resp, err := service.GetSummary(user.ID, time.Now())
require.NoError(t, err)
assert.Equal(t, 2, resp.TotalResidences)
}
func TestResidenceService_GetSummary_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.GetSummary(user.ID, time.Now())
require.NoError(t, err)
assert.Equal(t, 0, resp.TotalResidences)
}
// === GetShareCode ===
func TestResidenceService_GetShareCode_NoActiveCode(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
resp, err := service.GetShareCode(residence.ID, user.ID)
require.NoError(t, err)
assert.Nil(t, resp) // No active code
}
func TestResidenceService_GetShareCode_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.GetShareCode(residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
// === GenerateShareCode ===
func TestResidenceService_GenerateShareCode_NotOwner(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
_, err := service.GenerateShareCode(residence.ID, shared.ID, 24)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
}
func TestResidenceService_GenerateShareCode_DefaultExpiry(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Pass 0 hours — should default to 24
resp, err := service.GenerateShareCode(residence.ID, user.ID, 0)
require.NoError(t, err)
assert.NotEmpty(t, resp.ShareCode.Code)
}
// === GenerateSharePackage ===
func TestResidenceService_GenerateSharePackage(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
resp, err := service.GenerateSharePackage(residence.ID, user.ID, 48)
require.NoError(t, err)
assert.NotEmpty(t, resp.ShareCode)
assert.Equal(t, "Test House", resp.ResidenceName)
assert.Equal(t, "owner@test.com", resp.SharedBy)
}
func TestResidenceService_GenerateSharePackage_NotOwner(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
_, err := service.GenerateSharePackage(residence.ID, shared.ID, 24)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
}
// === JoinWithCode ===
func TestResidenceService_JoinWithCode_InvalidCode(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "user", "user@test.com", "Password123")
_, err := service.JoinWithCode("BADCODE", user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.share_code_invalid")
}
// === RemoveUser ===
func TestResidenceService_RemoveUser_NotOwner(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
// shared user tries to remove other — should fail because shared is not owner
err := service.RemoveUser(residence.ID, other.ID, shared.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
}
// === GetResidenceUsers ===
func TestResidenceService_GetResidenceUsers_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.GetResidenceUsers(residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
// === GetResidenceTypes ===
func TestResidenceService_GetResidenceTypes(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
resp, err := service.GetResidenceTypes()
require.NoError(t, err)
// SeedLookupData creates 4 residence types
assert.Len(t, resp, 4)
}
// === UpdateResidence with home profile fields ===
func TestResidenceService_UpdateResidence_HomeProfileFields(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
hasPool := true
hasGarage := true
heatingType := "Forced Air"
req := &requests.UpdateResidenceRequest{
HasPool: &hasPool,
HasGarage: &hasGarage,
HeatingType: &heatingType,
}
resp, err := service.UpdateResidence(residence.ID, user.ID, req)
require.NoError(t, err)
assert.True(t, resp.Data.HasPool)
assert.True(t, resp.Data.HasGarage)
}
// === CreateResidence with home profile fields ===
func TestResidenceService_CreateResidence_HomeProfileFields(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
hasPool := true
hasSeptic := true
req := &requests.CreateResidenceRequest{
Name: "New House",
StreetAddress: "456 Oak St",
City: "Dallas",
StateProvince: "TX",
PostalCode: "75201",
HasPool: &hasPool,
HasSeptic: &hasSeptic,
}
resp, err := service.CreateResidence(req, user.ID)
require.NoError(t, err)
assert.True(t, resp.Data.HasPool)
assert.True(t, resp.Data.HasSeptic)
}
// === Shared user GetResidence ===
func TestResidenceService_GetResidence_SharedUser(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, shared.ID)
resp, err := service.GetResidence(residence.ID, shared.ID, time.Now())
require.NoError(t, err)
assert.Equal(t, "Test House", resp.Name)
}
// === GetMyResidences with task repo (overdue counts + completion summaries) ===
func TestResidenceService_GetMyResidences_WithTaskRepo(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
taskRepo := repositories.NewTaskRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
service.SetTaskRepository(taskRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
testutil.CreateTestResidence(t, db, user.ID, "House 1")
testutil.CreateTestResidence(t, db, user.ID, "House 2")
resp, err := service.GetMyResidences(user.ID, time.Now())
require.NoError(t, err)
assert.Len(t, resp.Residences, 2)
}
// === GetResidence with task repo (completion summary) ===
func TestResidenceService_GetResidence_WithTaskRepo(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
taskRepo := repositories.NewTaskRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
service.SetTaskRepository(taskRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
resp, err := service.GetResidence(residence.ID, user.ID, time.Now())
require.NoError(t, err)
assert.Equal(t, "Test House", resp.Name)
}
// === GenerateShareCode with negative expiry defaults to 24 ===
func TestResidenceService_GenerateShareCode_NegativeExpiry(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
resp, err := service.GenerateShareCode(residence.ID, user.ID, -5)
require.NoError(t, err)
assert.NotEmpty(t, resp.ShareCode.Code)
}
// === GenerateSharePackage with default expiry ===
func TestResidenceService_GenerateSharePackage_DefaultExpiry(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Pass 0 hours — should default to 24
resp, err := service.GenerateSharePackage(residence.ID, user.ID, 0)
require.NoError(t, err)
assert.NotEmpty(t, resp.ShareCode)
assert.Equal(t, "Test House", resp.ResidenceName)
}
// === RemoveUser — trying to remove the owner by a different owner ID ===
func TestResidenceService_RemoveUser_OwnerViaResidenceOwnerID(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
residenceRepo.AddUser(residence.ID, sharedUser.ID)
// Try removing the owner (by residence.OwnerID) — even though requestingUserID != userIDToRemove
// The second check (userIDToRemove == residence.OwnerID) should catch this
err := service.RemoveUser(residence.ID, owner.ID, owner.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner")
}
// === GenerateTasksReport ===
func TestResidenceService_GenerateTasksReport(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create some tasks
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
report, err := service.GenerateTasksReport(residence.ID, user.ID)
require.NoError(t, err)
assert.Equal(t, residence.ID, report.ResidenceID)
assert.Equal(t, "Test House", report.ResidenceName)
assert.Equal(t, 2, report.TotalTasks)
}
func TestResidenceService_GenerateTasksReport_AccessDenied(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.GenerateTasksReport(residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
func TestResidenceService_GenerateTasksReport_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Non-existent residence — user has no access
_, err := service.GenerateTasksReport(9999, user.ID)
assert.Error(t, err)
}
// === GetShareCode with active code ===
func TestResidenceService_GetShareCode_WithActiveCode(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Generate a share code first
_, err := service.GenerateShareCode(residence.ID, user.ID, 24)
require.NoError(t, err)
// Now get the active code
resp, err := service.GetShareCode(residence.ID, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.NotEmpty(t, resp.Code)
}
// === CreateResidence with all boolean fields ===
func TestResidenceService_CreateResidence_AllBooleanFields(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
hasPool := true
hasSprinkler := true
hasSeptic := true
hasFireplace := true
hasGarage := true
hasBasement := true
hasAttic := true
req := &requests.CreateResidenceRequest{
Name: "Full Feature House",
StreetAddress: "789 Full St",
City: "Austin",
StateProvince: "TX",
PostalCode: "78701",
HasPool: &hasPool,
HasSprinklerSystem: &hasSprinkler,
HasSeptic: &hasSeptic,
HasFireplace: &hasFireplace,
HasGarage: &hasGarage,
HasBasement: &hasBasement,
HasAttic: &hasAttic,
}
resp, err := service.CreateResidence(req, user.ID)
require.NoError(t, err)
assert.True(t, resp.Data.HasPool)
assert.True(t, resp.Data.HasSprinklerSystem)
assert.True(t, resp.Data.HasSeptic)
assert.True(t, resp.Data.HasFireplace)
assert.True(t, resp.Data.HasGarage)
assert.True(t, resp.Data.HasBasement)
assert.True(t, resp.Data.HasAttic)
}
// === UpdateResidence with all optional fields ===
func TestResidenceService_UpdateResidence_AllOptionalFields(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Original Name")
newStreet := "456 New St"
newApt := "Apt 2B"
newState := "CA"
newPostal := "90210"
newCountry := "Canada"
bedrooms := 4
bathrooms := decimal.NewFromFloat(3.0)
sqft := 3000
lotSize := decimal.NewFromFloat(0.5)
yearBuilt := 2020
newDesc := "Nice house"
isPrimary := false
hasPool := true
hasSprinkler := true
hasSeptic := false
hasFireplace := true
hasGarage := true
hasBasement := false
hasAttic := true
coolingType := "Central AC"
waterHeaterType := "Tankless"
roofType := "Shingle"
exteriorType := "Brick"
flooringPrimary := "Hardwood"
landscapingType := "Xeriscape"
req := &requests.UpdateResidenceRequest{
StreetAddress: &newStreet,
ApartmentUnit: &newApt,
StateProvince: &newState,
PostalCode: &newPostal,
Country: &newCountry,
Bedrooms: &bedrooms,
Bathrooms: &bathrooms,
SquareFootage: &sqft,
LotSize: &lotSize,
YearBuilt: &yearBuilt,
Description: &newDesc,
IsPrimary: &isPrimary,
HasPool: &hasPool,
HasSprinklerSystem: &hasSprinkler,
HasSeptic: &hasSeptic,
HasFireplace: &hasFireplace,
HasGarage: &hasGarage,
HasBasement: &hasBasement,
HasAttic: &hasAttic,
CoolingType: &coolingType,
WaterHeaterType: &waterHeaterType,
RoofType: &roofType,
ExteriorType: &exteriorType,
FlooringPrimary: &flooringPrimary,
LandscapingType: &landscapingType,
}
resp, err := service.UpdateResidence(residence.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "456 New St", resp.Data.StreetAddress)
assert.Equal(t, "CA", resp.Data.StateProvince)
assert.True(t, resp.Data.HasPool)
assert.True(t, resp.Data.HasFireplace)
assert.True(t, resp.Data.HasAttic)
}
// === ListResidences with no residences ===
func TestResidenceService_ListResidences_NoResidences(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListResidences(user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
// === getSummaryForUser returns empty summary ===
func TestResidenceService_getSummaryForUser_ReturnsEmpty(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
summary := service.getSummaryForUser(999)
assert.Equal(t, 0, summary.TotalResidences)
}

View File

@@ -162,3 +162,179 @@ func TestDelete_NonexistentFile(t *testing.T) {
t.Fatalf("Delete should not error for non-existent file: %v", err)
}
}
// === isAllowedType ===
func TestIsAllowedType(t *testing.T) {
cfg := &config.StorageConfig{
UploadDir: t.TempDir(),
BaseURL: "/uploads",
MaxFileSize: 10 * 1024 * 1024,
AllowedTypes: "image/jpeg,image/png,application/pdf",
}
svc := NewStorageServiceForTest(cfg)
if !svc.isAllowedType("image/jpeg") {
t.Fatal("image/jpeg should be allowed")
}
if !svc.isAllowedType("image/png") {
t.Fatal("image/png should be allowed")
}
if !svc.isAllowedType("application/pdf") {
t.Fatal("application/pdf should be allowed")
}
if svc.isAllowedType("text/html") {
t.Fatal("text/html should not be allowed")
}
if svc.isAllowedType("") {
t.Fatal("empty MIME should not be allowed")
}
}
// === mimeTypesCompatible ===
func TestMimeTypesCompatible(t *testing.T) {
cfg := &config.StorageConfig{
UploadDir: t.TempDir(),
BaseURL: "/uploads",
MaxFileSize: 10 * 1024 * 1024,
AllowedTypes: "image/jpeg",
}
svc := NewStorageServiceForTest(cfg)
// Same primary type
if !svc.mimeTypesCompatible("image/jpeg", "image/png") {
t.Fatal("image/* types should be compatible")
}
// Different primary types
if svc.mimeTypesCompatible("image/jpeg", "application/pdf") {
t.Fatal("image and application should not be compatible")
}
// Same exact types
if !svc.mimeTypesCompatible("application/pdf", "application/octet-stream") {
t.Fatal("application/* types should be compatible")
}
}
// === getExtensionFromMimeType ===
func TestGetExtensionFromMimeType(t *testing.T) {
cfg := &config.StorageConfig{
UploadDir: t.TempDir(),
BaseURL: "/uploads",
MaxFileSize: 10 * 1024 * 1024,
AllowedTypes: "image/jpeg",
}
svc := NewStorageServiceForTest(cfg)
tests := []struct {
mimeType string
expected string
}{
{"image/jpeg", ".jpg"},
{"image/png", ".png"},
{"image/gif", ".gif"},
{"image/webp", ".webp"},
{"application/pdf", ".pdf"},
{"text/html", ""},
{"unknown/type", ""},
}
for _, tt := range tests {
got := svc.getExtensionFromMimeType(tt.mimeType)
if got != tt.expected {
t.Fatalf("getExtensionFromMimeType(%q) = %q, want %q", tt.mimeType, got, tt.expected)
}
}
}
// === GetUploadDir ===
func TestGetUploadDir(t *testing.T) {
svc, tmpDir := setupTestStorage(t, false)
if svc.GetUploadDir() != tmpDir {
t.Fatalf("GetUploadDir() = %q, want %q", svc.GetUploadDir(), tmpDir)
}
}
// === SetEncryptionService ===
func TestSetEncryptionService(t *testing.T) {
svc, _ := setupTestStorage(t, false)
// Initially no encryption service
if svc.encryptionSvc != nil && svc.encryptionSvc.IsEnabled() {
t.Fatal("encryption should not be enabled initially in plain mode")
}
encSvc, err := NewEncryptionService(validTestKey())
if err != nil {
t.Fatal(err)
}
svc.SetEncryptionService(encSvc)
if svc.encryptionSvc == nil {
t.Fatal("encryption service should be set")
}
}
// === NewStorageServiceForTest ===
func TestNewStorageServiceForTest(t *testing.T) {
cfg := &config.StorageConfig{
UploadDir: "/tmp/test",
BaseURL: "/uploads",
MaxFileSize: 5 * 1024 * 1024,
AllowedTypes: "image/jpeg, image/png, application/pdf",
}
svc := NewStorageServiceForTest(cfg)
// Should have 3 allowed types (whitespace trimmed)
if !svc.isAllowedType("image/jpeg") {
t.Fatal("image/jpeg should be allowed")
}
if !svc.isAllowedType("image/png") {
t.Fatal("image/png should be allowed")
}
if !svc.isAllowedType("application/pdf") {
t.Fatal("application/pdf should be allowed")
}
}
// === Delete only plain file ===
func TestDelete_OnlyPlainFile(t *testing.T) {
svc, tmpDir := setupTestStorage(t, false)
dir := filepath.Join(tmpDir, "images")
os.WriteFile(filepath.Join(dir, "only-plain.jpg"), []byte("plain"), 0644)
err := svc.Delete("/uploads/images/only-plain.jpg")
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "only-plain.jpg")); !os.IsNotExist(err) {
t.Fatal("plain file should be deleted")
}
}
// === Delete only enc file ===
func TestDelete_OnlyEncFile(t *testing.T) {
svc, tmpDir := setupTestStorage(t, false)
dir := filepath.Join(tmpDir, "documents")
os.WriteFile(filepath.Join(dir, "secret.pdf.enc"), []byte("encrypted"), 0644)
err := svc.Delete("/uploads/documents/secret.pdf")
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "secret.pdf.enc")); !os.IsNotExist(err) {
t.Fatal("encrypted file should be deleted")
}
}

View File

@@ -181,6 +181,107 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
}
// === GetSubscription ===
func TestSubscriptionService_GetSubscription(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := svc.GetSubscription(user.ID)
require.NoError(t, err)
assert.Equal(t, "free", resp.Tier)
assert.False(t, resp.IsPro)
}
func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create a pro subscription
future := time.Now().UTC().Add(30 * 24 * time.Hour)
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
ExpiresAt: &future,
Platform: "ios",
}
err := db.Create(sub).Error
require.NoError(t, err)
resp, err := svc.GetSubscription(user.ID)
require.NoError(t, err)
assert.Equal(t, "pro", resp.Tier)
assert.True(t, resp.IsPro)
assert.True(t, resp.IsActive)
}
// === CancelSubscription ===
func TestSubscriptionService_CancelSubscription(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Create a pro subscription with auto_renew
future := time.Now().UTC().Add(30 * 24 * time.Hour)
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
ExpiresAt: &future,
AutoRenew: true,
}
err := db.Create(sub).Error
require.NoError(t, err)
resp, err := svc.CancelSubscription(user.ID)
require.NoError(t, err)
assert.False(t, resp.AutoRenew)
}
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
future := time.Now().UTC().Add(30 * 24 * time.Hour)

View File

@@ -231,3 +231,472 @@ func TestSuggestionService_MultipleConditionsAllMustMatch(t *testing.T) {
assert.InDelta(t, expectedScore, resp.Suggestions[0].RelevanceScore, 0.01)
assert.Len(t, resp.Suggestions[0].MatchReasons, 3)
}
// === Malformed conditions JSON treated as universal ===
func TestSuggestionService_MalformedConditions(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
// Create template with malformed JSON conditions
tmpl := &models.TaskTemplate{
Title: "Bad JSON Template",
IsActive: true,
Conditions: json.RawMessage(`{bad json`),
}
err := service.db.Create(tmpl).Error
require.NoError(t, err)
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
// Should be treated as universal
assert.Equal(t, baseUniversalScore, resp.Suggestions[0].RelevanceScore)
}
// === Null conditions JSON treated as universal ===
func TestSuggestionService_NullConditions(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
tmpl := &models.TaskTemplate{
Title: "Null Conditions",
IsActive: true,
Conditions: json.RawMessage(`null`),
}
err := service.db.Create(tmpl).Error
require.NoError(t, err)
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Equal(t, baseUniversalScore, resp.Suggestions[0].RelevanceScore)
}
// === Template with property_type condition ===
func TestSuggestionService_PropertyTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
// Create a property type
propType := &models.ResidenceType{Name: "House"}
err := service.db.Create(propType).Error
require.NoError(t, err)
residence := &models.Residence{
OwnerID: user.ID,
Name: "My House",
IsActive: true,
IsPrimary: true,
PropertyTypeID: &propType.ID,
}
err = service.db.Create(residence).Error
require.NoError(t, err)
// Reload with PropertyType preloaded
err = service.db.Preload("PropertyType").First(residence, residence.ID).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "House Task", map[string]interface{}{
"property_type": "House",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "property_type:House")
}
// === CalculateProfileCompleteness with fully filled profile ===
func TestCalculateProfileCompleteness_FullProfile(t *testing.T) {
ht := "gas_furnace"
ct := "central_ac"
wht := "tank_gas"
rt := "asphalt_shingle"
et := "brick"
fp := "hardwood"
lt := "lawn"
residence := &models.Residence{
HeatingType: &ht,
CoolingType: &ct,
WaterHeaterType: &wht,
RoofType: &rt,
HasPool: true,
HasSprinklerSystem: true,
HasSeptic: true,
HasFireplace: true,
HasGarage: true,
HasBasement: true,
HasAttic: true,
ExteriorType: &et,
FlooringPrimary: &fp,
LandscapingType: &lt,
}
completeness := CalculateProfileCompleteness(residence)
assert.Equal(t, 1.0, completeness)
}
// === CalculateProfileCompleteness with empty profile ===
func TestCalculateProfileCompleteness_EmptyProfile(t *testing.T) {
residence := &models.Residence{}
completeness := CalculateProfileCompleteness(residence)
assert.Equal(t, 0.0, completeness)
}
// === Score capped at 1.0 ===
func TestSuggestionService_ScoreCappedAtOne(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
ht := "gas_furnace"
ct := "central_ac"
wht := "tank_gas"
rt := "asphalt_shingle"
et := "brick"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Full House",
IsActive: true,
IsPrimary: true,
HeatingType: &ht,
CoolingType: &ct,
WaterHeaterType: &wht,
RoofType: &rt,
ExteriorType: &et,
HasPool: true,
HasFireplace: true,
HasGarage: true,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
// Template that matches many fields — score should be capped at 1.0
createTemplateWithConditions(t, service, "Super Match", map[string]interface{}{
"heating_type": "gas_furnace",
"cooling_type": "central_ac",
"water_heater_type": "tank_gas",
"roof_type": "asphalt_shingle",
"exterior_type": "brick",
"has_pool": true,
"has_fireplace": true,
"has_garage": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.LessOrEqual(t, resp.Suggestions[0].RelevanceScore, 1.0)
}
// === Inactive templates are excluded ===
func TestSuggestionService_InactiveTemplateExcluded(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
// Create inactive template via raw SQL to bypass GORM default:true on is_active
err := service.db.Exec("INSERT INTO task_tasktemplate (title, is_active, conditions, created_at, updated_at) VALUES ('Inactive Task', false, '{}', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)").Error
require.NoError(t, err)
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
// === Template excluded when requires sprinkler but residence doesn't have it ===
func TestSuggestionService_ExcludedWhenSprinklerRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Sprinkler House")
createTemplateWithConditions(t, service, "Sprinkler Maintenance", map[string]interface{}{
"has_sprinkler_system": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
// === All bool field exclusions ===
func TestSuggestionService_ExcludedWhenSepticRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "City House")
createTemplateWithConditions(t, service, "Septic Pump", map[string]interface{}{
"has_septic": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
func TestSuggestionService_ExcludedWhenFireplaceRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Fireplace")
createTemplateWithConditions(t, service, "Chimney Sweep", map[string]interface{}{
"has_fireplace": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
func TestSuggestionService_ExcludedWhenGarageRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Garage")
createTemplateWithConditions(t, service, "Garage Door Service", map[string]interface{}{
"has_garage": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
func TestSuggestionService_ExcludedWhenBasementRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Slab Home")
createTemplateWithConditions(t, service, "Basement Waterproofing", map[string]interface{}{
"has_basement": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
func TestSuggestionService_ExcludedWhenAtticRequired(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Flat Roof")
createTemplateWithConditions(t, service, "Attic Insulation", map[string]interface{}{
"has_attic": true,
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp.Suggestions, 0)
}
// === String field matches for all remaining types ===
func TestSuggestionService_CoolingTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
coolingType := "central_ac"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Cool House",
IsActive: true,
IsPrimary: true,
CoolingType: &coolingType,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "AC Filter", map[string]interface{}{
"cooling_type": "central_ac",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "cooling_type:central_ac")
}
func TestSuggestionService_WaterHeaterTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
wht := "tank_gas"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Hot Water House",
IsActive: true,
IsPrimary: true,
WaterHeaterType: &wht,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "Flush Water Heater", map[string]interface{}{
"water_heater_type": "tank_gas",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "water_heater_type:tank_gas")
}
func TestSuggestionService_ExteriorTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
et := "brick"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Brick House",
IsActive: true,
IsPrimary: true,
ExteriorType: &et,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "Pressure Wash Brick", map[string]interface{}{
"exterior_type": "brick",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "exterior_type:brick")
}
func TestSuggestionService_FlooringPrimaryMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
fp := "hardwood"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Hardwood House",
IsActive: true,
IsPrimary: true,
FlooringPrimary: &fp,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "Refinish Floors", map[string]interface{}{
"flooring_primary": "hardwood",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "flooring_primary:hardwood")
}
func TestSuggestionService_LandscapingTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
lt := "lawn"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Lawn House",
IsActive: true,
IsPrimary: true,
LandscapingType: &lt,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "Fertilize Lawn", map[string]interface{}{
"landscaping_type": "lawn",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "landscaping_type:lawn")
}
func TestSuggestionService_RoofTypeMatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
rt := "asphalt_shingle"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Shingle House",
IsActive: true,
IsPrimary: true,
RoofType: &rt,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
createTemplateWithConditions(t, service, "Inspect Roof", map[string]interface{}{
"roof_type": "asphalt_shingle",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "roof_type:asphalt_shingle")
}
// === Mismatch on string field — no score for that field ===
func TestSuggestionService_HeatingTypeMismatch(t *testing.T) {
service := setupSuggestionService(t)
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
heatingType := "electric_furnace"
residence := &models.Residence{
OwnerID: user.ID,
Name: "Electric House",
IsActive: true,
IsPrimary: true,
HeatingType: &heatingType,
}
err := service.db.Create(residence).Error
require.NoError(t, err)
// Template wants gas_furnace but residence has electric_furnace
createTemplateWithConditions(t, service, "Gas Furnace Service", map[string]interface{}{
"heating_type": "gas_furnace",
})
resp, err := service.GetSuggestions(residence.ID, user.ID)
require.NoError(t, err)
require.Len(t, resp.Suggestions, 1)
// Should still be included but with partial_profile (no match, no exclude)
assert.Contains(t, resp.Suggestions[0].MatchReasons, "partial_profile")
}
// === templateConditions.isEmpty ===
func TestTemplateConditions_IsEmpty(t *testing.T) {
cond := &templateConditions{}
assert.True(t, cond.isEmpty())
ht := "gas"
cond2 := &templateConditions{HeatingType: &ht}
assert.False(t, cond2.isEmpty())
pool := true
cond3 := &templateConditions{HasPool: &pool}
assert.False(t, cond3.isEmpty())
pt := "House"
cond4 := &templateConditions{PropertyType: &pt}
assert.False(t, cond4.isEmpty())
}

File diff suppressed because it is too large Load Diff

467
internal/task/task_test.go Normal file
View File

@@ -0,0 +1,467 @@
package task
import (
"testing"
"time"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/task/categorization"
"github.com/treytartt/honeydue-api/internal/testutil"
)
var now = time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
func timePtr(t time.Time) *time.Time { return &t }
func completedTask() *models.Task {
return &models.Task{
BaseModel: models.BaseModel{ID: 1},
Completions: []models.TaskCompletion{{BaseModel: models.BaseModel{ID: 1}}},
// NextDueDate is nil → completed
}
}
func overdueTask() *models.Task {
past := now.AddDate(0, 0, -5)
return &models.Task{
BaseModel: models.BaseModel{ID: 2},
DueDate: &past,
}
}
func dueSoonTask() *models.Task {
soon := now.AddDate(0, 0, 10)
return &models.Task{
BaseModel: models.BaseModel{ID: 3},
DueDate: &soon,
}
}
func activeTask() *models.Task {
return &models.Task{
BaseModel: models.BaseModel{ID: 4},
}
}
// --- Predicate re-exports ---
func TestReExport_IsCompleted(t *testing.T) {
if !IsCompleted(completedTask()) {
t.Error("expected completed")
}
if IsCompleted(activeTask()) {
t.Error("expected not completed")
}
}
func TestReExport_IsOverdue(t *testing.T) {
if !IsOverdue(overdueTask(), now) {
t.Error("expected overdue")
}
if IsOverdue(dueSoonTask(), now) {
t.Error("expected not overdue")
}
}
func TestReExport_IsDueSoon(t *testing.T) {
if !IsDueSoon(dueSoonTask(), now, 30) {
t.Error("expected due soon")
}
}
func TestReExport_IsActive(t *testing.T) {
if !IsActive(activeTask()) {
t.Error("expected active")
}
cancelled := &models.Task{IsCancelled: true}
if IsActive(cancelled) {
t.Error("cancelled should not be active")
}
}
func TestReExport_IsArchived(t *testing.T) {
archived := &models.Task{IsArchived: true}
if !IsArchived(archived) {
t.Error("expected archived")
}
}
func TestReExport_IsCancelled(t *testing.T) {
cancelled := &models.Task{IsCancelled: true}
if !IsCancelled(cancelled) {
t.Error("expected cancelled")
}
}
func TestReExport_IsInProgress(t *testing.T) {
ip := &models.Task{InProgress: true}
if !IsInProgress(ip) {
t.Error("expected in progress")
}
}
func TestReExport_IsRecurring(t *testing.T) {
days := 30
freqID := uint(1)
recurring := &models.Task{
FrequencyID: &freqID,
Frequency: &models.TaskFrequency{Days: &days},
}
if !IsRecurring(recurring) {
t.Error("expected recurring")
}
if IsRecurring(activeTask()) {
t.Error("expected not recurring")
}
}
func TestReExport_IsOneTime(t *testing.T) {
if !IsOneTime(activeTask()) {
t.Error("expected one-time")
}
}
func TestReExport_HasCompletions(t *testing.T) {
if !HasCompletions(completedTask()) {
t.Error("expected has completions")
}
if HasCompletions(activeTask()) {
t.Error("expected no completions")
}
}
func TestReExport_GetCompletionCount(t *testing.T) {
ct := completedTask()
if GetCompletionCount(ct) != 1 {
t.Errorf("count = %d, want 1", GetCompletionCount(ct))
}
}
func TestReExport_EffectiveDate(t *testing.T) {
task := overdueTask()
ed := EffectiveDate(task)
if ed == nil {
t.Error("expected non-nil effective date")
}
// If NextDueDate is set, prefer it
next := now.AddDate(0, 1, 0)
task.NextDueDate = &next
ed = EffectiveDate(task)
if !ed.Equal(next) {
t.Errorf("expected NextDueDate, got %v", ed)
}
}
func TestReExport_IsUpcoming(t *testing.T) {
far := now.AddDate(0, 6, 0)
task := &models.Task{DueDate: &far}
if !IsUpcoming(task, now, 30) {
t.Error("expected upcoming")
}
}
func TestReExport_CategorizeTask(t *testing.T) {
col := CategorizeTask(overdueTask(), 30)
if col != categorization.ColumnOverdue {
t.Errorf("column = %v, want overdue", col)
}
}
func TestReExport_DetermineKanbanColumn(t *testing.T) {
col := DetermineKanbanColumn(overdueTask(), 30)
if col == "" {
t.Error("expected non-empty column string")
}
}
func TestReExport_CategorizeTasksIntoColumns(t *testing.T) {
tasks := []models.Task{*overdueTask(), *dueSoonTask(), *activeTask()}
result := CategorizeTasksIntoColumns(tasks, 30)
if result == nil {
t.Error("expected non-nil result")
}
}
func TestReExport_NewChain(t *testing.T) {
chain := NewChain()
if chain == nil {
t.Error("expected non-nil chain")
}
}
func TestReExport_Constants(t *testing.T) {
if ColumnOverdue.String() != "overdue_tasks" {
t.Errorf("ColumnOverdue = %q", ColumnOverdue.String())
}
if ColumnDueSoon.String() != "due_soon_tasks" {
t.Errorf("ColumnDueSoon = %q", ColumnDueSoon.String())
}
if ColumnUpcoming.String() != "upcoming_tasks" {
t.Errorf("ColumnUpcoming = %q", ColumnUpcoming.String())
}
if ColumnInProgress.String() != "in_progress_tasks" {
t.Errorf("ColumnInProgress = %q", ColumnInProgress.String())
}
if ColumnCompleted.String() != "completed_tasks" {
t.Errorf("ColumnCompleted = %q", ColumnCompleted.String())
}
if ColumnCancelled.String() != "cancelled_tasks" {
t.Errorf("ColumnCancelled = %q", ColumnCancelled.String())
}
}
// --- Scope re-exports (use SQLite in-memory DB) ---
func setupDB(t *testing.T) *gorm.DB {
return testutil.SetupTestDB(t)
}
func seedTask(t *testing.T, db *gorm.DB, task *models.Task) {
// Ensure we have a user and residence
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
res := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task.CreatedByID = user.ID
task.ResidenceID = res.ID
task.Version = 1
err := db.Create(task).Error
if err != nil {
t.Fatalf("failed to create task: %v", err)
}
}
func TestReExport_ScopeActive_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "active"})
var tasks []models.Task
db.Scopes(ScopeActive).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeCancelled_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "cancelled", IsCancelled: true})
var tasks []models.Task
db.Scopes(ScopeCancelled).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeArchived_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "archived", IsArchived: true})
var tasks []models.Task
db.Scopes(ScopeArchived).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeInProgress_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "inprog", InProgress: true})
var tasks []models.Task
db.Scopes(ScopeInProgress).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeNotInProgress_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "not inprog"})
var tasks []models.Task
db.Scopes(ScopeNotInProgress).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeCompleted_DB(t *testing.T) {
db := setupDB(t)
user := testutil.CreateTestUser(t, db, "u2", "u2@test.com", "password123")
res := testutil.CreateTestResidence(t, db, user.ID, "House2")
task := &models.Task{Title: "completed", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1}
db.Create(task)
// Add a completion and ensure NextDueDate is nil
db.Create(&models.TaskCompletion{TaskID: task.ID, CompletedByID: user.ID, CompletedAt: now})
var tasks []models.Task
db.Scopes(ScopeCompleted).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeNotCompleted_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "not completed"})
var tasks []models.Task
db.Scopes(ScopeNotCompleted).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeOverdue_DB(t *testing.T) {
db := setupDB(t)
past := now.AddDate(0, 0, -5)
seedTask(t, db, &models.Task{Title: "overdue", DueDate: &past})
var tasks []models.Task
db.Scopes(ScopeOverdue(now)).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeDueSoon_DB(t *testing.T) {
db := setupDB(t)
soon := now.AddDate(0, 0, 5)
seedTask(t, db, &models.Task{Title: "due soon", DueDate: &soon})
var tasks []models.Task
db.Scopes(ScopeDueSoon(now, 30)).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeUpcoming_DB(t *testing.T) {
db := setupDB(t)
far := now.AddDate(0, 6, 0)
seedTask(t, db, &models.Task{Title: "upcoming", DueDate: &far})
var tasks []models.Task
db.Scopes(ScopeUpcoming(now, 30)).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeForResidence_DB(t *testing.T) {
db := setupDB(t)
user := testutil.CreateTestUser(t, db, "u3", "u3@test.com", "password123")
res := testutil.CreateTestResidence(t, db, user.ID, "House3")
db.Create(&models.Task{Title: "t1", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1})
var tasks []models.Task
db.Scopes(ScopeForResidence(res.ID)).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeForResidences_DB(t *testing.T) {
db := setupDB(t)
user := testutil.CreateTestUser(t, db, "u4", "u4@test.com", "password123")
res1 := testutil.CreateTestResidence(t, db, user.ID, "H1")
res2 := testutil.CreateTestResidence(t, db, user.ID, "H2")
db.Create(&models.Task{Title: "t1", CreatedByID: user.ID, ResidenceID: res1.ID, Version: 1})
db.Create(&models.Task{Title: "t2", CreatedByID: user.ID, ResidenceID: res2.ID, Version: 1})
var tasks []models.Task
db.Scopes(ScopeForResidences([]uint{res1.ID, res2.ID})).Find(&tasks)
if len(tasks) != 2 {
t.Errorf("len = %d, want 2", len(tasks))
}
}
func TestReExport_ScopeDueInRange_DB(t *testing.T) {
db := setupDB(t)
due := now.AddDate(0, 0, 3)
seedTask(t, db, &models.Task{Title: "in range", DueDate: &due})
var tasks []models.Task
db.Scopes(ScopeDueInRange(now, now.AddDate(0, 0, 7))).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeHasDueDate_DB(t *testing.T) {
db := setupDB(t)
due := now.AddDate(0, 0, 1)
seedTask(t, db, &models.Task{Title: "with due", DueDate: &due})
var tasks []models.Task
db.Scopes(ScopeHasDueDate).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeNoDueDate_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "no due"})
var tasks []models.Task
db.Scopes(ScopeNoDueDate).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeHasCompletions_DB(t *testing.T) {
db := setupDB(t)
user := testutil.CreateTestUser(t, db, "u5", "u5@test.com", "password123")
res := testutil.CreateTestResidence(t, db, user.ID, "H5")
task := &models.Task{Title: "has comp", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1}
db.Create(task)
db.Create(&models.TaskCompletion{TaskID: task.ID, CompletedByID: user.ID, CompletedAt: now})
var tasks []models.Task
db.Scopes(ScopeHasCompletions).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeNoCompletions_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "no comp"})
var tasks []models.Task
db.Scopes(ScopeNoCompletions).Find(&tasks)
if len(tasks) != 1 {
t.Errorf("len = %d, want 1", len(tasks))
}
}
func TestReExport_ScopeOrderByDueDate_DB(t *testing.T) {
db := setupDB(t)
user := testutil.CreateTestUser(t, db, "u6", "u6@test.com", "password123")
res := testutil.CreateTestResidence(t, db, user.ID, "H6")
d1 := now.AddDate(0, 0, 5)
d2 := now.AddDate(0, 0, 1)
db.Create(&models.Task{Title: "later", CreatedByID: user.ID, ResidenceID: res.ID, DueDate: &d1, Version: 1})
db.Create(&models.Task{Title: "sooner", CreatedByID: user.ID, ResidenceID: res.ID, DueDate: &d2, Version: 1})
var tasks []models.Task
db.Scopes(ScopeOrderByDueDate).Find(&tasks)
if len(tasks) != 2 {
t.Fatalf("len = %d", len(tasks))
}
}
func TestReExport_ScopeOrderByPriority_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "prio test"})
var tasks []models.Task
db.Scopes(ScopeOrderByPriority).Find(&tasks)
if len(tasks) < 1 {
t.Error("expected at least 1 task")
}
}
func TestReExport_ScopeOrderByCreatedAt_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "created test"})
var tasks []models.Task
db.Scopes(ScopeOrderByCreatedAt).Find(&tasks)
if len(tasks) < 1 {
t.Error("expected at least 1 task")
}
}
func TestReExport_ScopeKanbanOrder_DB(t *testing.T) {
db := setupDB(t)
seedTask(t, db, &models.Task{Title: "kanban test"})
var tasks []models.Task
db.Scopes(ScopeKanbanOrder).Find(&tasks)
if len(tasks) < 1 {
t.Error("expected at least 1 task")
}
}

View File

@@ -0,0 +1,177 @@
package testutil
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
)
func TestSetupTestDB_Works(t *testing.T) {
db := SetupTestDB(t)
if db == nil {
t.Fatal("expected non-nil db")
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("failed to get sql.DB: %v", err)
}
if err := sqlDB.Ping(); err != nil {
t.Fatalf("ping failed: %v", err)
}
}
func TestSetupTestRouter_HasValidator(t *testing.T) {
e := SetupTestRouter()
if e == nil {
t.Fatal("expected non-nil router")
}
if e.Validator == nil {
t.Error("expected validator to be set")
}
}
func TestCreateTestUser_ReturnsUser(t *testing.T) {
db := SetupTestDB(t)
user := CreateTestUser(t, db, "testuser", "test@example.com", "password123")
if user == nil {
t.Fatal("expected non-nil user")
}
if user.ID == 0 {
t.Error("expected user to have an ID")
}
if user.Username != "testuser" {
t.Errorf("username = %q, want testuser", user.Username)
}
if user.Email != "test@example.com" {
t.Errorf("email = %q, want test@example.com", user.Email)
}
}
func TestSeedLookupData_PopulatesData(t *testing.T) {
db := SetupTestDB(t)
SeedLookupData(t, db)
// Verify residence types seeded
var rtCount int64
db.Table("residence_residencetype").Count(&rtCount)
if rtCount < 4 {
t.Errorf("residence types = %d, want ≥4", rtCount)
}
// Verify task categories seeded
var catCount int64
db.Table("task_taskcategory").Count(&catCount)
if catCount < 4 {
t.Errorf("task categories = %d, want ≥4", catCount)
}
// Verify task priorities seeded
var priCount int64
db.Table("task_taskpriority").Count(&priCount)
if priCount < 4 {
t.Errorf("task priorities = %d, want ≥4", priCount)
}
// Verify task frequencies seeded
var freqCount int64
db.Table("task_taskfrequency").Count(&freqCount)
if freqCount < 3 {
t.Errorf("task frequencies = %d, want ≥3", freqCount)
}
// Verify contractor specialties seeded
var specCount int64
db.Table("task_contractorspecialty").Count(&specCount)
if specCount < 4 {
t.Errorf("contractor specialties = %d, want ≥4", specCount)
}
}
func TestMockAuthMiddleware_SetsUser(t *testing.T) {
db := SetupTestDB(t)
user := CreateTestUser(t, db, "authuser", "auth@example.com", "pass")
e := echo.New()
mw := MockAuthMiddleware(user)
var capturedUser interface{}
handler := mw(func(c echo.Context) error {
capturedUser = c.Get("auth_user")
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if err := handler(c); err != nil {
t.Fatalf("handler error: %v", err)
}
if capturedUser == nil {
t.Fatal("expected auth_user to be set in context")
}
}
func TestCreateTestResidence_ReturnsResidence(t *testing.T) {
db := SetupTestDB(t)
user := CreateTestUser(t, db, "owner", "owner@example.com", "pass")
residence := CreateTestResidence(t, db, user.ID, "Test Home")
if residence == nil {
t.Fatal("expected non-nil residence")
}
if residence.ID == 0 {
t.Error("expected residence to have an ID")
}
if residence.Name != "Test Home" {
t.Errorf("name = %q, want Test Home", residence.Name)
}
}
func TestCreateTestTask_ReturnsTask(t *testing.T) {
db := SetupTestDB(t)
user := CreateTestUser(t, db, "taskowner", "task@example.com", "pass")
residence := CreateTestResidence(t, db, user.ID, "Task Home")
task := CreateTestTask(t, db, residence.ID, user.ID, "Fix the sink")
if task == nil {
t.Fatal("expected non-nil task")
}
if task.ID == 0 {
t.Error("expected task to have an ID")
}
if task.Title != "Fix the sink" {
t.Errorf("title = %q, want Fix the sink", task.Title)
}
}
func TestParseJSON_Valid(t *testing.T) {
data := []byte(`{"name":"test","count":42}`)
result := ParseJSON(t, data)
if result["name"] != "test" {
t.Errorf("name = %v, want test", result["name"])
}
}
func TestParseJSONArray_Valid(t *testing.T) {
data := []byte(`[{"id":1},{"id":2}]`)
result := ParseJSONArray(t, data)
if len(result) != 2 {
t.Errorf("len = %d, want 2", len(result))
}
}
func TestMakeRequestT_ReturnsRecorder(t *testing.T) {
e := SetupTestRouter()
e.GET("/test", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
})
rec := MakeRequestT(t, e, http.MethodGet, "/test", nil, "")
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want 200", rec.Code)
}
}

View File

@@ -1,9 +1,12 @@
package validator
import (
"fmt"
"testing"
govalidator "github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidatePasswordComplexity(t *testing.T) {
@@ -113,3 +116,118 @@ func TestFormatMessagePasswordComplexity(t *testing.T) {
t.Errorf("expected tag 'password_complexity', got %q", field.Tag)
}
}
func TestPasswordComplexity_AdditionalCases(t *testing.T) {
cv := NewCustomValidator()
type request struct {
Password string `json:"password" validate:"required,min=8,password_complexity"`
}
tests := []struct {
name string
pw string
valid bool
}{
{"no uppercase no digit", "password", false},
{"no lowercase", "PASSWORD1", false},
{"no digit", "Password", false},
{"too short", "Pass1", false},
{"valid standard", "Password1", true},
{"valid with special chars", "P@ssw0rd", true},
{"spaces with complexity", "Pass 1234", true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := request{Password: tc.pw}
err := cv.Validate(r)
if tc.valid {
assert.NoError(t, err, "expected %q to be valid", tc.pw)
} else {
assert.Error(t, err, "expected %q to be invalid", tc.pw)
}
})
}
}
func TestFormatValidationErrors_AllTags(t *testing.T) {
cv := NewCustomValidator()
type allTags struct {
Required string `json:"required" validate:"required"`
Email string `json:"email" validate:"email"`
MinLen string `json:"min_len" validate:"min=5"`
MaxLen string `json:"max_len" validate:"max=3"`
OneOf string `json:"one_of" validate:"oneof=a b c"`
URL string `json:"url" validate:"url"`
}
input := allTags{
Required: "", // fails required
Email: "bad", // fails email
MinLen: "ab", // fails min=5
MaxLen: "abcde", // fails max=3
OneOf: "z", // fails oneof
URL: "nope", // fails url
}
err := cv.Validate(input)
require.Error(t, err)
resp := FormatValidationErrors(err)
require.NotNil(t, resp)
assert.Equal(t, "Validation failed", resp.Error)
expectedMessages := map[string]string{
"required": "This field is required",
"email": "Must be a valid email address",
"min_len": "Must be at least 5 characters",
"max_len": "Must be at most 3 characters",
"one_of": "Must be one of: a b c",
"url": "Must be a valid URL",
}
for field, expectedMsg := range expectedMessages {
fe, ok := resp.Fields[field]
assert.True(t, ok, "expected field %q in error response", field)
if ok {
assert.Equal(t, expectedMsg, fe.Message, "message mismatch for field %q", field)
}
}
}
func TestFormatValidationErrors_NonValidationError(t *testing.T) {
err := fmt.Errorf("some random error")
resp := FormatValidationErrors(err)
require.NotNil(t, resp)
assert.Equal(t, "some random error", resp.Error)
assert.Nil(t, resp.Fields)
}
func TestNewCustomValidator_UsesJSONTagNames(t *testing.T) {
cv := NewCustomValidator()
type request struct {
FirstName string `json:"first_name" validate:"required"`
}
err := cv.Validate(request{})
require.Error(t, err)
resp := FormatValidationErrors(err)
require.NotNil(t, resp)
_, ok := resp.Fields["first_name"]
assert.True(t, ok, "expected JSON tag name 'first_name' in error fields")
}
func TestCustomValidator_Validate_Success(t *testing.T) {
cv := NewCustomValidator()
type request struct {
Name string `json:"name" validate:"required"`
}
err := cv.Validate(request{Name: "test"})
assert.NoError(t, err)
}

View File

@@ -0,0 +1,44 @@
package worker
import "encoding/json"
// Enqueuer defines the interface for enqueuing background email tasks.
type Enqueuer interface {
EnqueueWelcomeEmail(to, firstName, code string) error
EnqueueVerificationEmail(to, firstName, code string) error
EnqueuePasswordResetEmail(to, firstName, code, resetToken string) error
EnqueuePasswordChangedEmail(to, firstName string) error
}
// Verify TaskClient satisfies the interface at compile time.
var _ Enqueuer = (*TaskClient)(nil)
// BuildWelcomeEmailPayload marshals a WelcomeEmailPayload to JSON bytes.
func BuildWelcomeEmailPayload(to, firstName, code string) ([]byte, error) {
return json.Marshal(WelcomeEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
ConfirmationCode: code,
})
}
// BuildVerificationEmailPayload marshals a VerificationEmailPayload to JSON bytes.
func BuildVerificationEmailPayload(to, firstName, code string) ([]byte, error) {
return json.Marshal(VerificationEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
})
}
// BuildPasswordResetEmailPayload marshals a PasswordResetEmailPayload to JSON bytes.
func BuildPasswordResetEmailPayload(to, firstName, code, resetToken string) ([]byte, error) {
return json.Marshal(PasswordResetEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
ResetToken: resetToken,
})
}
// BuildPasswordChangedEmailPayload marshals an EmailPayload to JSON bytes.
func BuildPasswordChangedEmailPayload(to, firstName string) ([]byte, error) {
return json.Marshal(EmailPayload{To: to, FirstName: firstName})
}

View File

@@ -0,0 +1,79 @@
package worker
import (
"encoding/json"
"testing"
)
func TestBuildWelcomeEmailPayload_RoundTrip(t *testing.T) {
data, err := BuildWelcomeEmailPayload("a@b.com", "Alice", "CODE123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p WelcomeEmailPayload
if err := json.Unmarshal(data, &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "a@b.com" || p.FirstName != "Alice" || p.ConfirmationCode != "CODE123" {
t.Errorf("got %+v", p)
}
}
func TestBuildVerificationEmailPayload_RoundTrip(t *testing.T) {
data, err := BuildVerificationEmailPayload("b@c.com", "Bob", "VERIFY456")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p VerificationEmailPayload
if err := json.Unmarshal(data, &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "b@c.com" || p.FirstName != "Bob" || p.Code != "VERIFY456" {
t.Errorf("got %+v", p)
}
}
func TestBuildPasswordResetEmailPayload_RoundTrip(t *testing.T) {
data, err := BuildPasswordResetEmailPayload("c@d.com", "Carol", "RST789", "token-abc")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p PasswordResetEmailPayload
if err := json.Unmarshal(data, &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "c@d.com" || p.FirstName != "Carol" || p.Code != "RST789" || p.ResetToken != "token-abc" {
t.Errorf("got %+v", p)
}
}
func TestBuildPasswordChangedEmailPayload_RoundTrip(t *testing.T) {
data, err := BuildPasswordChangedEmailPayload("d@e.com", "Dave")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p EmailPayload
if err := json.Unmarshal(data, &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "d@e.com" || p.FirstName != "Dave" {
t.Errorf("got %+v", p)
}
}
func TestBuildWelcomeEmailPayload_Fields(t *testing.T) {
data, err := BuildWelcomeEmailPayload("test@example.com", "Test", "ABC")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify raw JSON contains expected keys
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("unmarshal: %v", err)
}
for _, key := range []string{"to", "first_name", "confirmation_code"} {
if _, ok := raw[key]; !ok {
t.Errorf("missing key %q in payload JSON", key)
}
}
}

View File

@@ -30,38 +30,43 @@ const (
// Handler handles background job processing
type Handler struct {
db *gorm.DB
taskRepo *repositories.TaskRepository
residenceRepo *repositories.ResidenceRepository
reminderRepo *repositories.ReminderRepository
notificationRepo *repositories.NotificationRepository
pushClient *push.Client
emailService *services.EmailService
notificationService *services.NotificationService
onboardingService *services.OnboardingEmailService
config *config.Config
db *gorm.DB
taskRepo TaskRepo
residenceRepo ResidenceRepo
reminderRepo ReminderRepo
notificationRepo NotificationRepo
pushClient PushSender
emailService EmailSender
notificationService NotificationSender
onboardingService OnboardingEmailSender
config *config.Config
}
// NewHandler creates a new job handler
func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.EmailService, notificationService *services.NotificationService, cfg *config.Config) *Handler {
// Create onboarding email service
var onboardingService *services.OnboardingEmailService
if emailService != nil {
onboardingService = services.NewOnboardingEmailService(db, emailService, cfg.Server.BaseURL)
h := &Handler{
db: db,
taskRepo: repositories.NewTaskRepository(db),
residenceRepo: repositories.NewResidenceRepository(db),
reminderRepo: repositories.NewReminderRepository(db),
notificationRepo: repositories.NewNotificationRepository(db),
config: cfg,
}
return &Handler{
db: db,
taskRepo: repositories.NewTaskRepository(db),
residenceRepo: repositories.NewResidenceRepository(db),
reminderRepo: repositories.NewReminderRepository(db),
notificationRepo: repositories.NewNotificationRepository(db),
pushClient: pushClient,
emailService: emailService,
notificationService: notificationService,
onboardingService: onboardingService,
config: cfg,
// Assign interface fields only when concrete values are non-nil
// to preserve correct nil checks on the interface values.
if pushClient != nil {
h.pushClient = pushClient
}
if emailService != nil {
h.emailService = emailService
h.onboardingService = services.NewOnboardingEmailService(db, emailService, cfg.Server.BaseURL)
}
if notificationService != nil {
h.notificationService = notificationService
}
return h
}
// HandleDailyDigest processes daily digest notifications with task statistics

View File

@@ -0,0 +1,39 @@
package jobs
import (
"fmt"
"strings"
"github.com/treytartt/honeydue-api/internal/models"
)
// BuildDigestMessage constructs the daily digest notification text.
func BuildDigestMessage(overdueCount, dueThisWeekCount int) (title, body string) {
title = "Daily Task Summary"
if overdueCount > 0 && dueThisWeekCount > 0 {
body = fmt.Sprintf("You have %d overdue task(s) and %d task(s) due this week", overdueCount, dueThisWeekCount)
} else if overdueCount > 0 {
body = fmt.Sprintf("You have %d overdue task(s) that need attention", overdueCount)
} else {
body = fmt.Sprintf("You have %d task(s) due this week", dueThisWeekCount)
}
return
}
// IsOverdueStage checks if a reminder stage string represents overdue.
func IsOverdueStage(stage string) bool {
return strings.HasPrefix(stage, "overdue")
}
// ExtractFrequencyDays gets interval days from a task's frequency.
func ExtractFrequencyDays(t *models.Task) *int {
if t.Frequency != nil && t.Frequency.Days != nil {
days := *t.Frequency.Days
return &days
}
if t.CustomIntervalDays != nil {
days := *t.CustomIntervalDays
return &days
}
return nil
}

View File

@@ -0,0 +1,226 @@
package jobs
import (
"encoding/json"
"testing"
"github.com/treytartt/honeydue-api/internal/models"
)
// --- BuildDigestMessage ---
func TestBuildDigestMessage_BothCounts(t *testing.T) {
title, body := BuildDigestMessage(3, 5)
if title != "Daily Task Summary" {
t.Errorf("title = %q, want %q", title, "Daily Task Summary")
}
want := "You have 3 overdue task(s) and 5 task(s) due this week"
if body != want {
t.Errorf("body = %q, want %q", body, want)
}
}
func TestBuildDigestMessage_OnlyOverdue(t *testing.T) {
_, body := BuildDigestMessage(2, 0)
want := "You have 2 overdue task(s) that need attention"
if body != want {
t.Errorf("body = %q, want %q", body, want)
}
}
func TestBuildDigestMessage_OnlyDueSoon(t *testing.T) {
_, body := BuildDigestMessage(0, 4)
want := "You have 4 task(s) due this week"
if body != want {
t.Errorf("body = %q, want %q", body, want)
}
}
func TestBuildDigestMessage_Title_AlwaysDailyTaskSummary(t *testing.T) {
cases := [][2]int{{1, 1}, {0, 1}, {1, 0}}
for _, c := range cases {
title, _ := BuildDigestMessage(c[0], c[1])
if title != "Daily Task Summary" {
t.Errorf("BuildDigestMessage(%d,%d) title = %q", c[0], c[1], title)
}
}
}
// --- IsOverdueStage ---
func TestIsOverdueStage_Overdue1_True(t *testing.T) {
if !IsOverdueStage("overdue_1") {
t.Error("expected true for overdue_1")
}
}
func TestIsOverdueStage_Overdue14_True(t *testing.T) {
if !IsOverdueStage("overdue_14") {
t.Error("expected true for overdue_14")
}
}
func TestIsOverdueStage_Reminder7d_False(t *testing.T) {
if IsOverdueStage("reminder_7d") {
t.Error("expected false for reminder_7d")
}
}
func TestIsOverdueStage_DayOf_False(t *testing.T) {
if IsOverdueStage("day_of") {
t.Error("expected false for day_of")
}
}
func TestIsOverdueStage_Empty_False(t *testing.T) {
if IsOverdueStage("") {
t.Error("expected false for empty string")
}
}
// --- ExtractFrequencyDays ---
func TestExtractFrequencyDays_WithFrequency(t *testing.T) {
days := 7
task := &models.Task{
Frequency: &models.TaskFrequency{Days: &days},
}
got := ExtractFrequencyDays(task)
if got == nil || *got != 7 {
t.Errorf("got %v, want 7", got)
}
}
func TestExtractFrequencyDays_WithCustomInterval(t *testing.T) {
custom := 14
task := &models.Task{
CustomIntervalDays: &custom,
}
got := ExtractFrequencyDays(task)
if got == nil || *got != 14 {
t.Errorf("got %v, want 14", got)
}
}
func TestExtractFrequencyDays_NilFrequency(t *testing.T) {
task := &models.Task{}
got := ExtractFrequencyDays(task)
if got != nil {
t.Errorf("got %v, want nil", got)
}
}
func TestExtractFrequencyDays_NilDays(t *testing.T) {
task := &models.Task{
Frequency: &models.TaskFrequency{},
}
got := ExtractFrequencyDays(task)
if got != nil {
t.Errorf("got %v, want nil", got)
}
}
// --- Email payload tests ---
func TestEmailPayload_Unmarshal_Valid(t *testing.T) {
data := []byte(`{"to":"a@b.com","subject":"hi","html_body":"<b>hi</b>","text_body":"hi"}`)
var p EmailPayload
if err := json.Unmarshal(data, &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "a@b.com" || p.Subject != "hi" {
t.Errorf("got %+v", p)
}
}
func TestEmailPayload_Unmarshal_Invalid(t *testing.T) {
var p EmailPayload
if err := json.Unmarshal([]byte(`{invalid}`), &p); err == nil {
t.Error("expected error for invalid JSON")
}
}
// --- NewSendEmailTask ---
func TestNewSendEmailTask_ReturnsTask(t *testing.T) {
task, err := NewSendEmailTask("a@b.com", "Subject", "<b>hi</b>", "hi")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if task.Type() != TypeSendEmail {
t.Errorf("task type = %q, want %q", task.Type(), TypeSendEmail)
}
}
func TestNewSendEmailTask_PayloadFields(t *testing.T) {
task, err := NewSendEmailTask("user@example.com", "Welcome", "<h1>Hello</h1>", "Hello")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p EmailPayload
if err := json.Unmarshal(task.Payload(), &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.To != "user@example.com" {
t.Errorf("To = %q, want %q", p.To, "user@example.com")
}
if p.Subject != "Welcome" {
t.Errorf("Subject = %q, want %q", p.Subject, "Welcome")
}
if p.HTMLBody != "<h1>Hello</h1>" {
t.Errorf("HTMLBody = %q, want %q", p.HTMLBody, "<h1>Hello</h1>")
}
if p.TextBody != "Hello" {
t.Errorf("TextBody = %q, want %q", p.TextBody, "Hello")
}
}
// --- NewSendPushTask ---
func TestNewSendPushTask_ReturnsTask(t *testing.T) {
task, err := NewSendPushTask(42, "Title", "Body", map[string]string{"key": "val"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if task.Type() != TypeSendPush {
t.Errorf("task type = %q, want %q", task.Type(), TypeSendPush)
}
}
func TestNewSendPushTask_PayloadFields(t *testing.T) {
data := map[string]string{"action": "open", "id": "123"}
task, err := NewSendPushTask(7, "Alert", "Something happened", data)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p PushPayload
if err := json.Unmarshal(task.Payload(), &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.UserID != 7 {
t.Errorf("UserID = %d, want 7", p.UserID)
}
if p.Title != "Alert" {
t.Errorf("Title = %q, want %q", p.Title, "Alert")
}
if p.Message != "Something happened" {
t.Errorf("Message = %q, want %q", p.Message, "Something happened")
}
if p.Data["action"] != "open" || p.Data["id"] != "123" {
t.Errorf("Data = %v, want map with action=open, id=123", p.Data)
}
}
func TestNewSendPushTask_NilData(t *testing.T) {
task, err := NewSendPushTask(1, "T", "M", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var p PushPayload
if err := json.Unmarshal(task.Payload(), &p); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if p.Data != nil {
t.Errorf("Data = %v, want nil", p.Data)
}
}

View File

@@ -0,0 +1,388 @@
package jobs
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"github.com/hibiken/asynq"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
)
// --- Mock implementations ---
type mockEmailSender struct {
sendFn func(to, subject, htmlBody, textBody string) error
}
func (m *mockEmailSender) SendEmail(to, subject, htmlBody, textBody string) error {
if m.sendFn != nil {
return m.sendFn(to, subject, htmlBody, textBody)
}
return nil
}
type mockPushSender struct {
sendFn func(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error
}
func (m *mockPushSender) SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error {
if m.sendFn != nil {
return m.sendFn(ctx, iosTokens, androidTokens, title, message, data)
}
return nil
}
type mockNotificationRepo struct {
findPrefsFn func(userID uint) (*models.NotificationPreference, error)
getTokensFn func(userID uint) ([]string, []string, error)
}
func (m *mockNotificationRepo) FindPreferencesByUser(userID uint) (*models.NotificationPreference, error) {
if m.findPrefsFn != nil {
return m.findPrefsFn(userID)
}
return &models.NotificationPreference{}, nil
}
func (m *mockNotificationRepo) GetActiveTokensForUser(userID uint) ([]string, []string, error) {
if m.getTokensFn != nil {
return m.getTokensFn(userID)
}
return nil, nil, nil
}
type mockReminderRepo struct {
cleanupFn func(daysOld int) (int64, error)
batchFn func(keys []repositories.ReminderKey) (map[int]bool, error)
logFn func(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error)
}
func (m *mockReminderRepo) HasSentReminderBatch(keys []repositories.ReminderKey) (map[int]bool, error) {
if m.batchFn != nil {
return m.batchFn(keys)
}
return map[int]bool{}, nil
}
func (m *mockReminderRepo) LogReminder(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error) {
if m.logFn != nil {
return m.logFn(taskID, userID, dueDate, stage, notificationID)
}
return &models.TaskReminderLog{}, nil
}
func (m *mockReminderRepo) CleanupOldLogs(daysOld int) (int64, error) {
if m.cleanupFn != nil {
return m.cleanupFn(daysOld)
}
return 0, nil
}
type mockOnboardingSender struct {
noResFn func() (int, error)
noTasksFn func() (int, error)
}
func (m *mockOnboardingSender) CheckAndSendNoResidenceEmails() (int, error) {
if m.noResFn != nil {
return m.noResFn()
}
return 0, nil
}
func (m *mockOnboardingSender) CheckAndSendNoTasksEmails() (int, error) {
if m.noTasksFn != nil {
return m.noTasksFn()
}
return 0, nil
}
type mockNotificationSender struct {
sendFn func(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error
}
func (m *mockNotificationSender) CreateAndSendTaskNotification(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error {
if m.sendFn != nil {
return m.sendFn(ctx, userID, notificationType, task)
}
return nil
}
// --- Helper to build a handler with mocks ---
func newTestHandler(opts ...func(*Handler)) *Handler {
h := &Handler{
config: &config.Config{},
}
for _, opt := range opts {
opt(h)
}
return h
}
func makeTask(taskType string, payload interface{}) *asynq.Task {
data, _ := json.Marshal(payload)
return asynq.NewTask(taskType, data)
}
// --- HandleSendEmail tests ---
func TestHandleSendEmail_Success(t *testing.T) {
var called bool
h := newTestHandler(func(h *Handler) {
h.emailService = &mockEmailSender{
sendFn: func(to, subject, htmlBody, textBody string) error {
called = true
if to != "test@example.com" {
t.Errorf("to = %q, want %q", to, "test@example.com")
}
return nil
},
}
})
task := makeTask(TypeSendEmail, EmailPayload{
To: "test@example.com",
Subject: "Hello",
HTMLBody: "<b>Hi</b>",
TextBody: "Hi",
})
err := h.HandleSendEmail(context.Background(), task)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Error("expected email service to be called")
}
}
func TestHandleSendEmail_InvalidPayload(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.emailService = &mockEmailSender{}
})
task := asynq.NewTask(TypeSendEmail, []byte(`{invalid`))
err := h.HandleSendEmail(context.Background(), task)
if err == nil {
t.Error("expected error for invalid payload")
}
}
func TestHandleSendEmail_SendFails(t *testing.T) {
sendErr := errors.New("SMTP error")
h := newTestHandler(func(h *Handler) {
h.emailService = &mockEmailSender{
sendFn: func(_, _, _, _ string) error { return sendErr },
}
})
task := makeTask(TypeSendEmail, EmailPayload{To: "a@b.com", Subject: "S"})
err := h.HandleSendEmail(context.Background(), task)
if !errors.Is(err, sendErr) {
t.Errorf("err = %v, want %v", err, sendErr)
}
}
func TestHandleSendEmail_NilService_Noop(t *testing.T) {
h := newTestHandler() // emailService is nil
task := makeTask(TypeSendEmail, EmailPayload{To: "a@b.com", Subject: "S"})
err := h.HandleSendEmail(context.Background(), task)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
// --- HandleSendPush tests ---
func TestHandleSendPush_Success(t *testing.T) {
var pushCalled bool
h := newTestHandler(func(h *Handler) {
h.notificationRepo = &mockNotificationRepo{
getTokensFn: func(userID uint) ([]string, []string, error) {
return []string{"ios-token"}, []string{"android-token"}, nil
},
}
h.pushClient = &mockPushSender{
sendFn: func(_ context.Context, ios, android []string, title, msg string, data map[string]string) error {
pushCalled = true
if len(ios) != 1 || ios[0] != "ios-token" {
t.Errorf("ios tokens = %v", ios)
}
return nil
},
}
})
task := makeTask(TypeSendPush, PushPayload{
UserID: 1,
Title: "Alert",
Message: "Hello",
Data: map[string]string{"type": "test"},
})
err := h.HandleSendPush(context.Background(), task)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !pushCalled {
t.Error("expected push client to be called")
}
}
func TestHandleSendPush_InvalidPayload(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.pushClient = &mockPushSender{}
})
task := asynq.NewTask(TypeSendPush, []byte(`{bad`))
err := h.HandleSendPush(context.Background(), task)
if err == nil {
t.Error("expected error for invalid payload")
}
}
func TestHandleSendPush_NoTokens(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.notificationRepo = &mockNotificationRepo{
getTokensFn: func(userID uint) ([]string, []string, error) {
return nil, nil, nil // no tokens
},
}
h.pushClient = &mockPushSender{
sendFn: func(_ context.Context, _, _ []string, _, _ string, _ map[string]string) error {
t.Error("push should not be called when no tokens")
return nil
},
}
})
task := makeTask(TypeSendPush, PushPayload{UserID: 1, Title: "T", Message: "M"})
err := h.HandleSendPush(context.Background(), task)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestHandleSendPush_NilClient_Noop(t *testing.T) {
h := newTestHandler() // pushClient is nil
task := makeTask(TypeSendPush, PushPayload{UserID: 1, Title: "T", Message: "M"})
err := h.HandleSendPush(context.Background(), task)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
// --- HandleOnboardingEmails tests ---
func TestHandleOnboardingEmails_Disabled(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.config.Features.OnboardingEmailsEnabled = false
})
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestHandleOnboardingEmails_NilService(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.config.Features.OnboardingEmailsEnabled = true
// onboardingService is nil
})
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestHandleOnboardingEmails_Success(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.config.Features.OnboardingEmailsEnabled = true
h.onboardingService = &mockOnboardingSender{
noResFn: func() (int, error) { return 2, nil },
noTasksFn: func() (int, error) { return 3, nil },
}
})
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestHandleOnboardingEmails_BothFail(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.config.Features.OnboardingEmailsEnabled = true
h.onboardingService = &mockOnboardingSender{
noResFn: func() (int, error) { return 0, errors.New("fail1") },
noTasksFn: func() (int, error) { return 0, errors.New("fail2") },
}
})
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
if err == nil {
t.Error("expected error when both sub-tasks fail")
}
}
func TestHandleOnboardingEmails_PartialFail(t *testing.T) {
h := newTestHandler(func(h *Handler) {
h.config.Features.OnboardingEmailsEnabled = true
h.onboardingService = &mockOnboardingSender{
noResFn: func() (int, error) { return 0, errors.New("fail") },
noTasksFn: func() (int, error) { return 1, nil },
}
})
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
if err != nil {
t.Fatalf("partial failure should not return error, got: %v", err)
}
}
// --- HandleReminderLogCleanup tests ---
func TestHandleReminderLogCleanup_Success(t *testing.T) {
var calledDays int
h := newTestHandler(func(h *Handler) {
h.reminderRepo = &mockReminderRepo{
cleanupFn: func(daysOld int) (int64, error) {
calledDays = daysOld
return 42, nil
},
}
})
err := h.HandleReminderLogCleanup(context.Background(), asynq.NewTask(TypeReminderLogCleanup, nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calledDays != 90 {
t.Errorf("cleanup called with %d days, want 90", calledDays)
}
}
func TestHandleReminderLogCleanup_Error(t *testing.T) {
cleanupErr := errors.New("db error")
h := newTestHandler(func(h *Handler) {
h.reminderRepo = &mockReminderRepo{
cleanupFn: func(daysOld int) (int64, error) { return 0, cleanupErr },
}
})
err := h.HandleReminderLogCleanup(context.Background(), asynq.NewTask(TypeReminderLogCleanup, nil))
if !errors.Is(err, cleanupErr) {
t.Errorf("err = %v, want %v", err, cleanupErr)
}
}

View File

@@ -0,0 +1,55 @@
package jobs
import (
"context"
"time"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
)
// TaskRepo defines task query operations needed by job handlers.
type TaskRepo interface {
GetOverdueTasks(now time.Time, opts repositories.TaskFilterOptions) ([]models.Task, error)
GetDueSoonTasks(now time.Time, daysThreshold int, opts repositories.TaskFilterOptions) ([]models.Task, error)
GetActiveTasksForUsers(now time.Time, opts repositories.TaskFilterOptions) ([]models.Task, error)
}
// ResidenceRepo defines residence query operations needed by job handlers.
type ResidenceRepo interface {
FindResidenceIDsByUser(userID uint) ([]uint, error)
}
// ReminderRepo defines reminder log operations needed by job handlers.
type ReminderRepo interface {
HasSentReminderBatch(keys []repositories.ReminderKey) (map[int]bool, error)
LogReminder(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error)
CleanupOldLogs(daysOld int) (int64, error)
}
// NotificationRepo defines notification preference operations needed by job handlers.
type NotificationRepo interface {
FindPreferencesByUser(userID uint) (*models.NotificationPreference, error)
GetActiveTokensForUser(userID uint) ([]string, []string, error)
}
// NotificationSender creates and sends task notifications.
type NotificationSender interface {
CreateAndSendTaskNotification(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error
}
// PushSender sends push notifications to device tokens.
type PushSender interface {
SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error
}
// EmailSender sends emails.
type EmailSender interface {
SendEmail(to, subject, htmlBody, textBody string) error
}
// OnboardingEmailSender sends onboarding campaign emails.
type OnboardingEmailSender interface {
CheckAndSendNoResidenceEmails() (int, error)
CheckAndSendNoTasksEmails() (int, error)
}

View File

@@ -1,8 +1,6 @@
package worker
import (
"encoding/json"
"github.com/hibiken/asynq"
"github.com/rs/zerolog/log"
)
@@ -58,10 +56,7 @@ func (c *TaskClient) Close() error {
// EnqueueWelcomeEmail enqueues a welcome email task
func (c *TaskClient) EnqueueWelcomeEmail(to, firstName, code string) error {
payload, err := json.Marshal(WelcomeEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
ConfirmationCode: code,
})
payload, err := BuildWelcomeEmailPayload(to, firstName, code)
if err != nil {
return err
}
@@ -79,10 +74,7 @@ func (c *TaskClient) EnqueueWelcomeEmail(to, firstName, code string) error {
// EnqueueVerificationEmail enqueues a verification email task
func (c *TaskClient) EnqueueVerificationEmail(to, firstName, code string) error {
payload, err := json.Marshal(VerificationEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
})
payload, err := BuildVerificationEmailPayload(to, firstName, code)
if err != nil {
return err
}
@@ -100,11 +92,7 @@ func (c *TaskClient) EnqueueVerificationEmail(to, firstName, code string) error
// EnqueuePasswordResetEmail enqueues a password reset email task
func (c *TaskClient) EnqueuePasswordResetEmail(to, firstName, code, resetToken string) error {
payload, err := json.Marshal(PasswordResetEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
ResetToken: resetToken,
})
payload, err := BuildPasswordResetEmailPayload(to, firstName, code, resetToken)
if err != nil {
return err
}
@@ -122,7 +110,7 @@ func (c *TaskClient) EnqueuePasswordResetEmail(to, firstName, code, resetToken s
// EnqueuePasswordChangedEmail enqueues a password changed confirmation email
func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error {
payload, err := json.Marshal(EmailPayload{To: to, FirstName: firstName})
payload, err := BuildPasswordChangedEmailPayload(to, firstName)
if err != nil {
return err
}

View File

@@ -0,0 +1,110 @@
package worker
import (
"encoding/json"
"testing"
)
// --- Payload roundtrip tests ---
func TestWelcomeEmailPayload_MarshalRoundtrip(t *testing.T) {
original := WelcomeEmailPayload{
EmailPayload: EmailPayload{To: "a@b.com", FirstName: "Alice"},
ConfirmationCode: "ABC123",
}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var got WelcomeEmailPayload
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.To != original.To || got.FirstName != original.FirstName || got.ConfirmationCode != original.ConfirmationCode {
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
}
}
func TestVerificationEmailPayload_MarshalRoundtrip(t *testing.T) {
original := VerificationEmailPayload{
EmailPayload: EmailPayload{To: "b@c.com", FirstName: "Bob"},
Code: "999888",
}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var got VerificationEmailPayload
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.To != original.To || got.FirstName != original.FirstName || got.Code != original.Code {
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
}
}
func TestPasswordResetEmailPayload_MarshalRoundtrip(t *testing.T) {
original := PasswordResetEmailPayload{
EmailPayload: EmailPayload{To: "c@d.com", FirstName: "Carol"},
Code: "XYZ",
ResetToken: "tok-abc-123",
}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var got PasswordResetEmailPayload
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.To != original.To || got.FirstName != original.FirstName || got.Code != original.Code || got.ResetToken != original.ResetToken {
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
}
}
func TestPasswordChangedEmailPayload_MarshalRoundtrip(t *testing.T) {
original := EmailPayload{To: "d@e.com", FirstName: "Dave"}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var got EmailPayload
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.To != original.To || got.FirstName != original.FirstName {
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
}
}
// --- Task type constant tests ---
func TestTaskTypeConstants_Unique(t *testing.T) {
types := []string{
TypeWelcomeEmail,
TypeVerificationEmail,
TypePasswordResetEmail,
TypePasswordChangedEmail,
}
seen := make(map[string]bool)
for _, typ := range types {
if seen[typ] {
t.Errorf("duplicate task type: %q", typ)
}
seen[typ] = true
}
}
func TestTaskTypeConstants_EmailPrefix(t *testing.T) {
types := []string{
TypeWelcomeEmail,
TypeVerificationEmail,
TypePasswordResetEmail,
TypePasswordChangedEmail,
}
for _, typ := range types {
if len(typ) < 6 || typ[:6] != "email:" {
t.Errorf("task type %q does not have 'email:' prefix", typ)
}
}
}

123
pkg/utils/logger_test.go Normal file
View File

@@ -0,0 +1,123 @@
package utils
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
)
func TestInitLogger_Debug(t *testing.T) {
InitLogger(true)
assert.Equal(t, zerolog.DebugLevel, zerolog.GlobalLevel())
}
func TestInitLogger_Production(t *testing.T) {
InitLogger(false)
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
}
func TestInitLoggerWithWriter_AdditionalWriter(t *testing.T) {
var buf bytes.Buffer
InitLoggerWithWriter(false, &buf)
log.Info().Msg("test message")
// The additional writer should have received JSON output
assert.Contains(t, buf.String(), "test message")
}
func TestInitLoggerWithWriter_Debug_AdditionalWriter(t *testing.T) {
var buf bytes.Buffer
InitLoggerWithWriter(true, &buf)
log.Info().Msg("debug test")
// Additional writer receives JSON even in debug mode
assert.Contains(t, buf.String(), "debug test")
}
func TestInitLoggerWithWriter_NilWriter(t *testing.T) {
// Should not panic with nil additional writer
InitLoggerWithWriter(false, nil)
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
}
func TestEchoLogger_ReturnsMiddleware(t *testing.T) {
mw := EchoLogger()
assert.NotNil(t, mw)
// Test that it processes requests without error
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestEchoRecovery_ReturnsMiddleware(t *testing.T) {
mw := EchoRecovery()
assert.NotNil(t, mw)
// Test panic recovery
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/panic/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := mw(func(c echo.Context) error {
panic("test panic")
})
// Should not panic — the middleware recovers
assert.NotPanics(t, func() {
_ = handler(c)
})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Internal server error")
}
func TestEchoLogger_WithQueryParams(t *testing.T) {
mw := EchoLogger()
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/tasks/?limit=10&offset=0", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
assert.NoError(t, err)
}
func TestEchoLogger_ErrorStatus(t *testing.T) {
mw := EchoLogger()
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/notfound/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusNotFound, "not found")
})
err := handler(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, rec.Code)
}