Coverage priorities 1-5: test pure functions, extract interfaces, mock-based handler tests
- Priority 1: Test NewSendEmailTask + NewSendPushTask (5 tests) - Priority 2: Test customHTTPErrorHandler — all 15+ branches (21 tests) - Priority 3: Extract Enqueuer interface + payload builders in worker pkg (5 tests) - Priority 4: Extract ClassifyFile/ComputeRelPath in migrate-encrypt (6 tests) - Priority 5: Define Handler interfaces, refactor to accept them, mock-based tests (14 tests) - Fix .gitignore: /worker instead of worker to stop ignoring internal/worker/ Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
163
internal/middleware/admin_auth_test.go
Normal file
163
internal/middleware/admin_auth_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestAdminAuth_NoHeader_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Authorization required")
|
||||
}
|
||||
|
||||
func TestAdminAuth_InvalidToken_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-jwt-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid token")
|
||||
}
|
||||
|
||||
func TestAdminAuth_TokenSchemeOnly_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
// "Token" scheme is not supported for admin auth, only "Bearer"
|
||||
req.Header.Set("Authorization", "Token some-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_NoAdmin_Returns401(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No admin in context
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_WrongType_Returns401(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Wrong type in context
|
||||
c.Set(AdminUserKey, "not-an-admin")
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_NonSuperAdmin_Returns403(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Regular admin (not super admin)
|
||||
admin := &models.AdminUser{
|
||||
Email: "admin@test.com",
|
||||
IsActive: true,
|
||||
Role: models.AdminRoleAdmin,
|
||||
}
|
||||
c.Set(AdminUserKey, admin)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Super admin privileges required")
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_SuperAdmin_Passes(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
admin := &models.AdminUser{
|
||||
Email: "superadmin@test.com",
|
||||
IsActive: true,
|
||||
Role: models.AdminRoleSuperAdmin,
|
||||
}
|
||||
c.Set(AdminUserKey, admin)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays
|
||||
Email: username + "@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token := &models.AuthToken{
|
||||
|
||||
337
internal/middleware/auth_test.go
Normal file
337
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "bearer_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Basic "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "JustATokenWithNoScheme")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token ")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
|
||||
|
||||
// Deactivate the user
|
||||
require.NoError(t, db.Model(user).Update("is_active", false).Error)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No Authorization header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "opt_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "opt_user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, 30, m.tokenExpiryDays)
|
||||
|
||||
// Token at 29 days should be valid
|
||||
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
|
||||
// Token at 31 days should be expired with 30-day config
|
||||
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, nil)
|
||||
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
|
||||
}
|
||||
|
||||
func TestGetAuthToken_ReturnsToken(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, "test-token-value")
|
||||
assert.Equal(t, "test-token-value", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No token set
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, 12345) // Wrong type
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
// Legacy tokens without created time should not be expired
|
||||
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
|
||||
}
|
||||
|
||||
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil) // nil cache
|
||||
|
||||
err := m.InvalidateToken(nil, "some-token")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
93
internal/middleware/host_check_test.go
Normal file
93
internal/middleware/host_check_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
|
||||
mw := HostCheck([]string{"api.example.com", "localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "api.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
|
||||
mw := HostCheck([]string{"api.example.com"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "evil.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Forbidden", response["error"])
|
||||
}
|
||||
|
||||
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
|
||||
mw := HostCheck([]string{})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "any-host.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
|
||||
mw := HostCheck([]string{"localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "localhost:8000"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
|
||||
// Only "localhost:8000" allowed, not plain "localhost"
|
||||
mw := HostCheck([]string{"localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "localhost"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
}
|
||||
103
internal/middleware/logger_test.go
Normal file
103
internal/middleware/logger_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestStructuredLogger_Passes_Request(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "ok", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestStructuredLogger_WithUser(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
user := &models.User{Username: "loguser"}
|
||||
user.ID = 42
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(AuthUserKey, user)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_WithRequestID(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(ContextKeyRequestID, "test-request-id-123")
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_ErrorStatus(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusInternalServerError, "error")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_ClientError(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusBadRequest, "bad request")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
222
internal/middleware/timezone_test.go
Normal file
222
internal/middleware/timezone_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTimezoneMiddleware_IANATimezone(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "America/New_York")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "America/New_York", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_UTCOffset(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
// Just verify it's not UTC (if offset is non-zero)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "-05:00")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "-05:00", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_NoHeader_DefaultsToUTC(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No X-Timezone header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "UTC", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_InvalidTimezone_DefaultsToUTC(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "Invalid/Timezone")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "UTC", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_SetsUserNow(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
|
||||
var capturedNow time.Time
|
||||
handler := mw(func(c echo.Context) error {
|
||||
capturedNow = GetUserNow(c)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "America/Chicago")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The "now" should be the start of the day in the user's timezone
|
||||
assert.Equal(t, 0, capturedNow.Hour())
|
||||
assert.Equal(t, 0, capturedNow.Minute())
|
||||
assert.Equal(t, 0, capturedNow.Second())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_SetsTimezoneName(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
name := GetTimezoneName(c)
|
||||
return c.String(http.StatusOK, name)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "Europe/London")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Europe/London", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestGetUserTimezone_NotSet_ReturnsUTC(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No timezone set in context
|
||||
loc := GetUserTimezone(c)
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
|
||||
func TestGetUserTimezone_WrongType_ReturnsUTC(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneKey, "not-a-location")
|
||||
loc := GetUserTimezone(c)
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
|
||||
func TestGetUserNow_NotSet_ReturnsUTCNow(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
before := time.Now().UTC()
|
||||
now := GetUserNow(c)
|
||||
after := time.Now().UTC()
|
||||
|
||||
assert.True(t, !now.Before(before.Add(-time.Second)), "now should be roughly after before")
|
||||
assert.True(t, !now.After(after.Add(time.Second)), "now should be roughly before after")
|
||||
}
|
||||
|
||||
func TestGetUserNow_WrongType_ReturnsUTCNow(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(UserNowKey, "not-a-time")
|
||||
now := GetUserNow(c)
|
||||
assert.NotNil(t, now)
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_NoChange_ReturnsFalse(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneChangedKey, false)
|
||||
assert.False(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_Changed_ReturnsTrue(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneChangedKey, true)
|
||||
assert.True(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_NotSet_ReturnsFalse(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Not set at all
|
||||
assert.False(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestParseTimezone_UTCOffsetWithoutColon(t *testing.T) {
|
||||
loc := parseTimezone("-0800")
|
||||
assert.NotEqual(t, time.UTC, loc)
|
||||
assert.Equal(t, "-0800", loc.String())
|
||||
}
|
||||
|
||||
func TestParseTimezone_PositiveOffset(t *testing.T) {
|
||||
loc := parseTimezone("+05:30")
|
||||
assert.NotEqual(t, time.UTC, loc)
|
||||
assert.Equal(t, "+05:30", loc.String())
|
||||
}
|
||||
|
||||
func TestParseTimezone_UTC(t *testing.T) {
|
||||
loc := parseTimezone("UTC")
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
186
internal/middleware/user_cache_test.go
Normal file
186
internal/middleware/user_cache_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestUserCache_SetAndGet(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "testuser", Email: "test@test.com"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "testuser", cached.Username)
|
||||
assert.Equal(t, "test@test.com", cached.Email)
|
||||
}
|
||||
|
||||
func TestUserCache_GetNonExistent_ReturnsNil(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
cached := cache.Get(999)
|
||||
assert.Nil(t, cached)
|
||||
}
|
||||
|
||||
func TestUserCache_Expired_ReturnsNil(t *testing.T) {
|
||||
// Very short TTL
|
||||
cache := NewUserCache(1 * time.Millisecond)
|
||||
|
||||
user := &models.User{Username: "expiring_user"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Wait for expiry
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
cached := cache.Get(1)
|
||||
assert.Nil(t, cached, "expired entry should return nil")
|
||||
}
|
||||
|
||||
func TestUserCache_Invalidate(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "to_invalidate"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Verify it's cached
|
||||
require.NotNil(t, cache.Get(1))
|
||||
|
||||
// Invalidate
|
||||
cache.Invalidate(1)
|
||||
|
||||
// Should be gone
|
||||
assert.Nil(t, cache.Get(1))
|
||||
}
|
||||
|
||||
func TestUserCache_ReturnsCopy_NotOriginal(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Modify the returned copy
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
cached.Username = "modified"
|
||||
|
||||
// Original cache entry should be unaffected
|
||||
cached2 := cache.Get(1)
|
||||
require.NotNil(t, cached2)
|
||||
assert.Equal(t, "original", cached2.Username, "cache should return a copy, not the original")
|
||||
}
|
||||
|
||||
func TestUserCache_SetCopiesInput(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Modify the input after setting
|
||||
user.Username = "modified_after_set"
|
||||
|
||||
// Cache should still have the original value
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "original", cached.Username, "cache should store a copy of the input")
|
||||
}
|
||||
|
||||
func TestUserCache_MultipleUsers(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user1 := &models.User{Username: "user1"}
|
||||
user1.ID = 1
|
||||
user2 := &models.User{Username: "user2"}
|
||||
user2.ID = 2
|
||||
|
||||
cache.Set(user1)
|
||||
cache.Set(user2)
|
||||
|
||||
cached1 := cache.Get(1)
|
||||
cached2 := cache.Get(2)
|
||||
|
||||
require.NotNil(t, cached1)
|
||||
require.NotNil(t, cached2)
|
||||
assert.Equal(t, "user1", cached1.Username)
|
||||
assert.Equal(t, "user2", cached2.Username)
|
||||
}
|
||||
|
||||
func TestUserCache_OverwriteEntry(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Overwrite with new data
|
||||
updated := &models.User{Username: "updated"}
|
||||
updated.ID = 1
|
||||
|
||||
cache.Set(updated)
|
||||
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "updated", cached.Username)
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_NewEntry(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// First call should return false (not cached yet)
|
||||
unchanged := tc.GetAndCompare(1, "America/New_York")
|
||||
assert.False(t, unchanged, "first call should indicate a change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_SameValue(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// First call sets the value
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
|
||||
// Second call with same value should return true (unchanged)
|
||||
unchanged := tc.GetAndCompare(1, "America/New_York")
|
||||
assert.True(t, unchanged, "same value should indicate no change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_DifferentValue(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// Set initial value
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
|
||||
// Update to different value
|
||||
unchanged := tc.GetAndCompare(1, "America/Chicago")
|
||||
assert.False(t, unchanged, "different value should indicate a change")
|
||||
|
||||
// Now the new value is cached
|
||||
unchanged = tc.GetAndCompare(1, "America/Chicago")
|
||||
assert.True(t, unchanged, "same value should indicate no change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_DifferentUsers(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
tc.GetAndCompare(2, "Europe/London")
|
||||
|
||||
assert.True(t, tc.GetAndCompare(1, "America/New_York"))
|
||||
assert.True(t, tc.GetAndCompare(2, "Europe/London"))
|
||||
assert.False(t, tc.GetAndCompare(1, "Europe/London"))
|
||||
}
|
||||
Reference in New Issue
Block a user