package middleware

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/labstack/echo/v4"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// serveHostCheck sends a GET /api/test/ request with the given Host through
// HostCheck(allowedHosts) wrapping okHandler, and returns the recorded response.
//
// It fails the test immediately if the handler chain returns an error: every
// case in this file expects HostCheck to write its response (200 via okHandler,
// or 403 itself) rather than propagate an error to Echo.
func serveHostCheck(t *testing.T, allowedHosts []string, host string) *httptest.ResponseRecorder {
	t.Helper()

	handler := HostCheck(allowedHosts)(okHandler)

	req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
	req.Host = host
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	require.NoError(t, handler(c))
	return rec
}

// TestHostCheck_AllowedHost_Passes verifies that a Host present in the allow
// list reaches the wrapped handler.
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
	rec := serveHostCheck(t, []string{"api.example.com", "localhost:8000"}, "api.example.com")
	assert.Equal(t, http.StatusOK, rec.Code)
}

// TestHostCheck_DisallowedHost_Returns403 verifies that an unlisted Host is
// rejected with 403 and a JSON body whose "error" field is "Forbidden".
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
	rec := serveHostCheck(t, []string{"api.example.com"}, "evil.example.com")
	assert.Equal(t, http.StatusForbidden, rec.Code)

	var response struct {
		Error string `json:"error"`
	}
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response))
	assert.Equal(t, "Forbidden", response.Error)
}

// TestHostCheck_EmptyAllowedHosts_AllPass verifies that an empty allow list
// lets any Host through.
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
	rec := serveHostCheck(t, []string{}, "any-host.example.com")
	assert.Equal(t, http.StatusOK, rec.Code)
}

// TestHostCheck_LocalhostWithPort_Passes verifies that a host:port entry
// admits a request carrying exactly that host:port.
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
	rec := serveHostCheck(t, []string{"localhost:8000"}, "localhost:8000")
	assert.Equal(t, http.StatusOK, rec.Code)
}

// TestHostCheck_LocalhostWithoutPort_Denied verifies that the port is part of
// the match: only "localhost:8000" is allowed, not plain "localhost".
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
	rec := serveHostCheck(t, []string{"localhost:8000"}, "localhost")
	assert.Equal(t, http.StatusForbidden, rec.Code)
}