Files
honeyDueAPI/internal/middleware/host_check_test.go
Trey T bec880886b Coverage priorities 1-5: test pure functions, extract interfaces, mock-based handler tests
- Priority 1: Test NewSendEmailTask + NewSendPushTask (5 tests)
- Priority 2: Test customHTTPErrorHandler — all 15+ branches (21 tests)
- Priority 3: Extract Enqueuer interface + payload builders in worker pkg (5 tests)
- Priority 4: Extract ClassifyFile/ComputeRelPath in migrate-encrypt (6 tests)
- Priority 5: Define Handler interfaces, refactor to accept them, mock-based tests (14 tests)
- Fix .gitignore: /worker instead of worker to stop ignoring internal/worker/

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 20:30:09 -05:00

94 lines
2.3 KiB
Go

package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHostCheck_AllowedHost_Passes verifies that every host on the allowlist
// is let through to the wrapped handler and answered with 200 OK.
//
// Each allowlist entry gets its own subtest, so a regression that only
// honours the first entry (or only bare hostnames, not host:port pairs)
// is caught instead of hidden behind a single happy-path request.
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
	allowed := []string{"api.example.com", "localhost:8000"}
	mw := HostCheck(allowed)
	e := echo.New()

	for _, host := range allowed {
		t.Run(host, func(t *testing.T) {
			handler := mw(okHandler)
			req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
			req.Host = host
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			err := handler(c)
			require.NoError(t, err)
			assert.Equal(t, http.StatusOK, rec.Code)
		})
	}
}
// TestHostCheck_DisallowedHost_Returns403 verifies that a request whose Host
// is not on the allowlist is rejected with 403 and a JSON body whose "error"
// field is "Forbidden", and that the wrapped handler is never invoked.
//
// The middleware is expected to write the rejection itself and return a nil
// error, rather than propagating an error to Echo's error handler.
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
	mw := HostCheck([]string{"api.example.com"})
	e := echo.New()

	// A 403 status alone does not prove the request was blocked: if the
	// middleware wrote 403 and then still called next, the recorded status
	// would typically still read 403. Track invocation explicitly.
	nextCalled := false
	handler := mw(func(c echo.Context) error {
		nextCalled = true
		return c.NoContent(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
	req.Host = "evil.example.com"
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err := handler(c)
	require.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, rec.Code)
	assert.False(t, nextCalled, "downstream handler must not run for a disallowed host")

	var response map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response))
	assert.Equal(t, "Forbidden", response["error"])
}
// TestHostCheck_EmptyAllowedHosts_AllPass checks that an empty allowlist
// disables host filtering: a request for an arbitrary host reaches the
// wrapped handler and is answered with 200 OK.
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
	guard := HostCheck([]string{})
	server := echo.New()
	next := guard(okHandler)

	request := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
	request.Host = "any-host.example.com"
	recorder := httptest.NewRecorder()

	require.NoError(t, next(server.NewContext(request, recorder)))
	assert.Equal(t, http.StatusOK, recorder.Code)
}
// TestHostCheck_LocalhostWithPort_Passes checks that a host:port allowlist
// entry matches a request whose Host header carries exactly that host:port.
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
	const target = "localhost:8000"

	guard := HostCheck([]string{target})
	server := echo.New()
	next := guard(okHandler)

	request := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
	request.Host = target
	recorder := httptest.NewRecorder()

	require.NoError(t, next(server.NewContext(request, recorder)))
	assert.Equal(t, http.StatusOK, recorder.Code)
}
// TestHostCheck_LocalhostWithoutPort_Denied verifies that host matching is
// exact: when only "localhost:8000" is allowed, a request for plain
// "localhost" is rejected with 403 and never reaches the wrapped handler.
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
	// Only "localhost:8000" allowed, not plain "localhost".
	mw := HostCheck([]string{"localhost:8000"})
	e := echo.New()

	// Spy on the downstream handler: a 403 status by itself would not reveal
	// a middleware that writes the rejection but still calls next.
	nextCalled := false
	handler := mw(func(c echo.Context) error {
		nextCalled = true
		return c.NoContent(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
	req.Host = "localhost"
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err := handler(c)
	require.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, rec.Code)
	assert.False(t, nextCalled, "downstream handler must not run for a port-less localhost")
}