- Priority 1: Test NewSendEmailTask + NewSendPushTask (5 tests)
- Priority 2: Test customHTTPErrorHandler — all 15+ branches (21 tests)
- Priority 3: Extract Enqueuer interface + payload builders in worker pkg (5 tests)
- Priority 4: Extract ClassifyFile/ComputeRelPath in migrate-encrypt (6 tests)
- Priority 5: Define Handler interfaces, refactor to accept them, mock-based tests (14 tests)
- Fix .gitignore: /worker instead of worker to stop ignoring internal/worker/

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
94 lines
2.3 KiB
Go
94 lines
2.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
|
|
mw := HostCheck([]string{"api.example.com", "localhost:8000"})
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
|
req.Host = "api.example.com"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
}
|
|
|
|
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
|
|
mw := HostCheck([]string{"api.example.com"})
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
|
req.Host = "evil.example.com"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
|
|
|
var response map[string]interface{}
|
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Forbidden", response["error"])
|
|
}
|
|
|
|
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
|
|
mw := HostCheck([]string{})
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
|
req.Host = "any-host.example.com"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
}
|
|
|
|
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
|
|
mw := HostCheck([]string{"localhost:8000"})
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
|
req.Host = "localhost:8000"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
}
|
|
|
|
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
|
|
// Only "localhost:8000" allowed, not plain "localhost"
|
|
mw := HostCheck([]string{"localhost:8000"})
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
|
req.Host = "localhost"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
|
}
|