Harden API security: input validation, safe auth extraction, new tests, and deploy config
Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -35,7 +35,9 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
|
||||
return func(c echo.Context) error {
|
||||
var tokenString string
|
||||
|
||||
// Get token from Authorization header
|
||||
// Get token from Authorization header only.
|
||||
// Query parameter authentication is intentionally not supported
|
||||
// because tokens in URLs leak into server logs and browser history.
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
// Check Bearer prefix
|
||||
@@ -45,11 +47,6 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
|
||||
}
|
||||
}
|
||||
|
||||
// If no header token, check query parameter (for WebSocket connections)
|
||||
if tokenString == "" {
|
||||
tokenString = c.QueryParam("token")
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Authorization required"})
|
||||
}
|
||||
@@ -121,7 +118,10 @@ func RequireSuperAdmin() echo.MiddlewareFunc {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
|
||||
}
|
||||
|
||||
adminUser := admin.(*models.AdminUser)
|
||||
adminUser, ok := admin.(*models.AdminUser)
|
||||
if !ok {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
|
||||
}
|
||||
if !adminUser.IsSuperAdmin() {
|
||||
return c.JSON(http.StatusForbidden, map[string]interface{}{"error": "Super admin privileges required"})
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
// Cache miss - look up token in database
|
||||
user, err = m.getUserFromDatabase(token)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("token", token[:8]+"...").Msg("Token authentication failed")
|
||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||
return apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
@@ -200,13 +200,18 @@ func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) erro
|
||||
return m.cache.InvalidateAuthToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context.
|
||||
// Returns nil if the context value is missing or not of the expected type.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
user := c.Get(AuthUserKey)
|
||||
if user == nil {
|
||||
val := c.Get(AuthUserKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return user.(*models.User)
|
||||
user, ok := val.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the auth token from the Echo context
|
||||
@@ -226,3 +231,12 @@ func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// truncateToken safely truncates a token string for logging.
|
||||
// Returns at most the first 8 characters followed by "...".
|
||||
func truncateToken(token string) string {
|
||||
if len(token) > 8 {
|
||||
return token[:8] + "..."
|
||||
}
|
||||
return token + "..."
|
||||
}
|
||||
|
||||
119
internal/middleware/auth_safety_test.go
Normal file
119
internal/middleware/auth_safety_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
)
|
||||
|
||||
func TestGetAuthUser_NilContext_ReturnsNil(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No user set in context
|
||||
user := GetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestGetAuthUser_WrongType_ReturnsNil(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Set wrong type in context — should NOT panic
|
||||
c.Set(AuthUserKey, "not-a-user")
|
||||
user := GetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestGetAuthUser_ValidUser_ReturnsUser(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
expected := &models.User{Username: "testuser"}
|
||||
c.Set(AuthUserKey, expected)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "testuser", user.Username)
|
||||
}
|
||||
|
||||
func TestMustGetAuthUser_Nil_Returns401(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
user, err := MustGetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMustGetAuthUser_WrongType_Returns401(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthUserKey, 12345)
|
||||
user, err := MustGetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTokenTruncation_ShortToken_NoPanic(t *testing.T) {
|
||||
// Ensure truncateToken does not panic on short tokens
|
||||
assert.NotPanics(t, func() {
|
||||
result := truncateToken("ab")
|
||||
assert.Equal(t, "ab...", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTruncation_EmptyToken_NoPanic(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
result := truncateToken("")
|
||||
assert.Equal(t, "...", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTruncation_LongToken_Truncated(t *testing.T) {
|
||||
result := truncateToken("abcdefghijklmnop")
|
||||
assert.Equal(t, "abcdefgh...", result)
|
||||
}
|
||||
|
||||
func TestAdminAuth_QueryParamToken_Rejected(t *testing.T) {
|
||||
// SEC-20: Admin JWT via query parameter must be rejected.
|
||||
// Tokens in URLs leak into server logs and browser history.
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "should not reach here")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
|
||||
// Request with token only in query param, no Authorization header
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test?token=some-jwt-token", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err) // handler writes JSON directly, no Echo error
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "query param token must be rejected")
|
||||
assert.Contains(t, rec.Body.String(), "Authorization required")
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// validRequestID matches alphanumeric characters and hyphens, 1-64 chars.
|
||||
var validRequestID = regexp.MustCompile(`^[a-zA-Z0-9\-]{1,64}$`)
|
||||
|
||||
const (
|
||||
// HeaderXRequestID is the header key for request correlation IDs
|
||||
HeaderXRequestID = "X-Request-ID"
|
||||
@@ -17,9 +22,11 @@ const (
|
||||
func RequestIDMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Use existing request ID from header if present, otherwise generate one
|
||||
// Use existing request ID from header if present and valid, otherwise generate one.
|
||||
// Sanitize to alphanumeric + hyphens only (max 64 chars) to prevent
|
||||
// log injection via control characters or overly long values.
|
||||
reqID := c.Request().Header.Get(HeaderXRequestID)
|
||||
if reqID == "" {
|
||||
if reqID == "" || !validRequestID.MatchString(reqID) {
|
||||
reqID = uuid.New().String()
|
||||
}
|
||||
|
||||
|
||||
125
internal/middleware/request_id_test.go
Normal file
125
internal/middleware/request_id_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRequestID_ValidID_Preserved(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, "abc-123-def")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "abc-123-def", rec.Body.String())
|
||||
assert.Equal(t, "abc-123-def", rec.Header().Get(HeaderXRequestID))
|
||||
}
|
||||
|
||||
func TestRequestID_Empty_GeneratesNew(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// No X-Request-ID header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
// Should be a UUID (36 chars: 8-4-4-4-12)
|
||||
assert.Len(t, rec.Body.String(), 36)
|
||||
}
|
||||
|
||||
func TestRequestID_ControlChars_Sanitized(t *testing.T) {
|
||||
// SEC-29: Client-supplied X-Request-ID with control characters must be rejected.
|
||||
tests := []struct {
|
||||
name string
|
||||
inputID string
|
||||
}{
|
||||
{"newline injection", "abc\ndef"},
|
||||
{"carriage return", "abc\rdef"},
|
||||
{"null byte", "abc\x00def"},
|
||||
{"tab character", "abc\tdef"},
|
||||
{"html tags", "abc<script>alert(1)</script>"},
|
||||
{"spaces", "abc def"},
|
||||
{"semicolons", "abc;def"},
|
||||
{"unicode", "abc\u200bdef"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, tt.inputID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
// The malicious ID should be replaced with a generated UUID
|
||||
assert.NotEqual(t, tt.inputID, rec.Body.String(),
|
||||
"control chars should be rejected, got original ID back")
|
||||
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_TooLong_Sanitized(t *testing.T) {
|
||||
// SEC-29: X-Request-ID longer than 64 chars should be rejected.
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
longID := strings.Repeat("a", 65)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, longID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, longID, rec.Body.String(), "overly long ID should be replaced")
|
||||
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
|
||||
}
|
||||
|
||||
func TestRequestID_MaxLength_Accepted(t *testing.T) {
|
||||
// Exactly 64 chars of valid characters should be accepted
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
maxID := strings.Repeat("a", 64)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, maxID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxID, rec.Body.String(), "64-char valid ID should be accepted")
|
||||
}
|
||||
19
internal/middleware/sanitize.go
Normal file
19
internal/middleware/sanitize.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import "strings"
|
||||
|
||||
// SanitizeSortColumn validates a user-supplied sort column against an allowlist.
|
||||
// Returns defaultCol if the input is empty or not in the allowlist.
|
||||
// This prevents SQL injection via ORDER BY clauses.
|
||||
func SanitizeSortColumn(input string, allowedCols []string, defaultCol string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return defaultCol
|
||||
}
|
||||
for _, col := range allowedCols {
|
||||
if strings.EqualFold(input, col) {
|
||||
return col
|
||||
}
|
||||
}
|
||||
return defaultCol
|
||||
}
|
||||
59
internal/middleware/sanitize_test.go
Normal file
59
internal/middleware/sanitize_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSanitizeSortColumn_AllowedColumn_Passes(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("created_at", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_CaseInsensitive(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("Created_At", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_SQLInjection_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"drop table", "created_at; DROP TABLE auth_user; --"},
|
||||
{"union select", "name UNION SELECT * FROM auth_user"},
|
||||
{"or 1=1", "name OR 1=1"},
|
||||
{"semicolon", "created_at;"},
|
||||
{"subquery", "(SELECT password FROM auth_user LIMIT 1)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeSortColumn(tt.input, allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result, "SQL injection attempt should return default")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_Empty_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_Whitespace_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn(" ", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_UnknownColumn_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("nonexistent_column", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
@@ -79,22 +79,30 @@ func parseTimezone(tz string) *time.Location {
|
||||
}
|
||||
|
||||
// GetUserTimezone retrieves the user's timezone from the Echo context.
|
||||
// Returns UTC if not set.
|
||||
// Returns UTC if not set or if the stored value is not a *time.Location.
|
||||
func GetUserTimezone(c echo.Context) *time.Location {
|
||||
loc := c.Get(TimezoneKey)
|
||||
if loc == nil {
|
||||
val := c.Get(TimezoneKey)
|
||||
if val == nil {
|
||||
return time.UTC
|
||||
}
|
||||
return loc.(*time.Location)
|
||||
loc, ok := val.(*time.Location)
|
||||
if !ok {
|
||||
return time.UTC
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
// GetUserNow retrieves the timezone-aware "now" time from the Echo context.
|
||||
// This represents the start of the current day in the user's timezone.
|
||||
// Returns time.Now().UTC() if not set.
|
||||
// Returns time.Now().UTC() if not set or if the stored value is not a time.Time.
|
||||
func GetUserNow(c echo.Context) time.Time {
|
||||
now := c.Get(UserNowKey)
|
||||
if now == nil {
|
||||
val := c.Get(UserNowKey)
|
||||
if val == nil {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return now.(time.Time)
|
||||
now, ok := val.(time.Time)
|
||||
if !ok {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user