feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Delegates all credential management (login, register, password reset, email verification, social sign-in) to Ory Kratos. The Go API now acts as a resource server: the new KratosAuth middleware validates sessions against the Kratos whoami endpoint, writes the local User mirror into Echo context, and all existing domain handlers continue working unchanged. Hand-rolled token auth, AuthToken model, apple_auth/ google_auth services, and the auth refresh flow are removed. Tests are updated to use the fake-token middleware pattern so existing integration assertions require no rewrite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,438 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthUserKey is the key used to store the authenticated user in the context
|
||||
AuthUserKey = "auth_user"
|
||||
// AuthTokenKey is the key used to store the token in the context
|
||||
AuthTokenKey = "auth_token"
|
||||
// TokenCacheTTL is the duration to cache tokens in Redis. Tokens are
|
||||
// valid for DefaultTokenExpiryDays (90), and explicit logout invalidates
|
||||
// the cache, so a long TTL here just means most authed requests skip the
|
||||
// auth-token SQL query entirely.
|
||||
TokenCacheTTL = 1 * time.Hour
|
||||
// TokenCachePrefix is the prefix for token cache keys
|
||||
TokenCachePrefix = "auth_token_"
|
||||
// UserCacheTTL is how long full user records are cached in memory to
|
||||
// avoid hitting the database on every authenticated request. Bumped from
|
||||
// 30s — at 30s the trace showed a SELECT auth_user query on most warm
|
||||
// requests because users aren't in cache long enough to hit twice.
|
||||
UserCacheTTL = 5 * time.Minute
|
||||
// UserCacheMaxSize bounds the per-pod in-memory user cache. With ~1KB
|
||||
// per User struct, 5000 entries = ~5MB per pod. Older entries are
|
||||
// evicted LRU before the limit is exceeded.
|
||||
UserCacheMaxSize = 5000
|
||||
|
||||
// DefaultTokenExpiryDays is the default number of days before a token expires.
|
||||
DefaultTokenExpiryDays = 90
|
||||
)
|
||||
|
||||
// AuthMiddleware provides token authentication middleware
|
||||
type AuthMiddleware struct {
|
||||
db *gorm.DB
|
||||
cache *services.CacheService
|
||||
userCache *UserCache
|
||||
tokenExpiryDays int
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware instance
|
||||
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
tokenExpiryDays: DefaultTokenExpiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
|
||||
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
|
||||
expiryDays := DefaultTokenExpiryDays
|
||||
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
|
||||
expiryDays = cfg.Security.TokenExpiryDays
|
||||
}
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
tokenExpiryDays: expiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenExpiryDuration returns the token expiry duration.
|
||||
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
|
||||
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
|
||||
}
|
||||
|
||||
// isTokenExpired checks if a token's created timestamp indicates expiry.
|
||||
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
|
||||
if created.IsZero() {
|
||||
return false // Legacy tokens without created time are not expired
|
||||
}
|
||||
return time.Since(created) > m.TokenExpiryDuration()
|
||||
}
|
||||
|
||||
// TokenAuth returns an Echo middleware that validates token authentication
|
||||
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Extract token from Authorization header
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
// Try to get user from cache first (includes expiry check)
|
||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||
if err == nil && user != nil {
|
||||
// Cache hit - set user in context and continue
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Check if the cache indicated token expiry
|
||||
if err != nil && err.Error() == "token expired" {
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache miss - look up token in database
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||
return apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
// Check token expiry
|
||||
if m.isTokenExpired(authToken.Created) {
|
||||
log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache the user ID and token creation time for future requests
|
||||
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
|
||||
}
|
||||
|
||||
// Set user in context
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it
|
||||
func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
// No token or invalid format - continue without user
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try cache first
|
||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||
if err == nil && user != nil {
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try database
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
|
||||
if err == nil && !m.isTokenExpired(authToken.Created) {
|
||||
m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireVerified returns middleware that rejects users whose email is not
|
||||
// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive
|
||||
// actions — e.g. generating residence share codes — behind proof that the
|
||||
// account actually controls its email address.
|
||||
func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
var verified bool
|
||||
err := m.db.WithContext(c.Request().Context()).
|
||||
Model(&models.UserProfile{}).
|
||||
Where("user_id = ?", user.ID).
|
||||
Select("verified").
|
||||
Scan(&verified).Error
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
if !verified {
|
||||
return apperrors.Forbidden("error.email_not_verified")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractToken extracts the token from the Authorization header
|
||||
func extractToken(c echo.Context) (string, error) {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", fmt.Errorf("authorization header required")
|
||||
}
|
||||
|
||||
// Support both "Token xxx" (Django style) and "Bearer xxx" formats
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid authorization header format")
|
||||
}
|
||||
|
||||
scheme := parts[0]
|
||||
token := parts[1]
|
||||
|
||||
if scheme != "Token" && scheme != "Bearer" {
|
||||
return "", fmt.Errorf("invalid authorization scheme: %s", scheme)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("token is empty")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// getUserFromCache tries to get user from Redis cache, then from the
|
||||
// in-memory user cache, before falling back to the database.
|
||||
// Returns a "token expired" error if the cached creation time indicates expiry.
|
||||
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
||||
if m.cache == nil {
|
||||
return nil, fmt.Errorf("cache not available")
|
||||
}
|
||||
|
||||
userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("token not in cache")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check token expiry from cached creation time
|
||||
if createdUnix > 0 {
|
||||
created := time.Unix(createdUnix, 0)
|
||||
if m.isTokenExpired(created) {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Try in-memory user cache first to avoid a DB round-trip
|
||||
if cached := m.userCache.Get(userID); cached != nil {
|
||||
if !cached.IsActive {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
m.userCache.Invalidate(userID)
|
||||
return nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// In-memory cache miss — fetch from database
|
||||
var user models.User
|
||||
if err := m.db.WithContext(ctx).First(&user, userID).Error; err != nil {
|
||||
// User was deleted - invalidate caches
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
|
||||
// Store in in-memory cache for subsequent requests
|
||||
m.userCache.Set(&user)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabaseWithToken looks up the token in the database and returns
|
||||
// both the user and the auth token record (for expiry checking). The ctx is
|
||||
// threaded into the GORM session so the SQL span attaches to the request trace.
|
||||
//
|
||||
// Uses a single JOIN query instead of GORM's Preload (which issues 2 SELECTs).
|
||||
// Over a transatlantic link this saves ~110ms RTT per cache miss.
|
||||
func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token string) (*models.User, *models.AuthToken, error) {
|
||||
// Flat result row: every column from auth_user prefixed `u_`, every
|
||||
// column from user_authtoken left in its native shape. Mapping to two
|
||||
// structs is mechanical so we don't need a struct tag soup.
|
||||
type joinedRow struct {
|
||||
// AuthToken columns
|
||||
Key string `gorm:"column:key"`
|
||||
Created time.Time `gorm:"column:created"`
|
||||
UserID uint `gorm:"column:user_id"`
|
||||
// User columns (prefixed to avoid collision with UserID)
|
||||
UID uint `gorm:"column:u_id"`
|
||||
UUsername string `gorm:"column:u_username"`
|
||||
UEmail string `gorm:"column:u_email"`
|
||||
UFirstName string `gorm:"column:u_first_name"`
|
||||
ULastName string `gorm:"column:u_last_name"`
|
||||
UPassword string `gorm:"column:u_password"`
|
||||
UIsActive bool `gorm:"column:u_is_active"`
|
||||
UIsStaff bool `gorm:"column:u_is_staff"`
|
||||
UIsSuper bool `gorm:"column:u_is_superuser"`
|
||||
UDateJoined time.Time `gorm:"column:u_date_joined"`
|
||||
ULastLogin *time.Time `gorm:"column:u_last_login"`
|
||||
}
|
||||
|
||||
var row joinedRow
|
||||
err := m.db.WithContext(ctx).
|
||||
Table("user_authtoken AS t").
|
||||
Select(`
|
||||
t.key, t.created, t.user_id,
|
||||
u.id AS u_id,
|
||||
u.username AS u_username,
|
||||
u.email AS u_email,
|
||||
u.first_name AS u_first_name,
|
||||
u.last_name AS u_last_name,
|
||||
u.password AS u_password,
|
||||
u.is_active AS u_is_active,
|
||||
u.is_staff AS u_is_staff,
|
||||
u.is_superuser AS u_is_superuser,
|
||||
u.date_joined AS u_date_joined,
|
||||
u.last_login AS u_last_login
|
||||
`).
|
||||
Joins("INNER JOIN auth_user u ON u.id = t.user_id").
|
||||
Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored
|
||||
Limit(1).
|
||||
Scan(&row).Error
|
||||
if err != nil || row.Key == "" {
|
||||
return nil, nil, fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
user := models.User{
|
||||
ID: row.UID,
|
||||
Username: row.UUsername,
|
||||
Email: row.UEmail,
|
||||
FirstName: row.UFirstName,
|
||||
LastName: row.ULastName,
|
||||
Password: row.UPassword,
|
||||
IsActive: row.UIsActive,
|
||||
IsStaff: row.UIsStaff,
|
||||
IsSuperuser: row.UIsSuper,
|
||||
DateJoined: row.UDateJoined,
|
||||
LastLogin: row.ULastLogin,
|
||||
}
|
||||
authToken := models.AuthToken{
|
||||
Key: row.Key,
|
||||
Created: row.Created,
|
||||
UserID: row.UserID,
|
||||
User: user,
|
||||
}
|
||||
|
||||
if !user.IsActive {
|
||||
return nil, nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
|
||||
m.userCache.Set(&user)
|
||||
return &user, &authToken, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabase looks up the token in the database and caches the
|
||||
// resulting user record in memory.
|
||||
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
|
||||
func (m *AuthMiddleware) getUserFromDatabase(ctx context.Context, token string) (*models.User, error) {
|
||||
user, _, err := m.getUserFromDatabaseWithToken(ctx, token)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// cacheTokenInfo caches the user ID and token creation time for a token
|
||||
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
|
||||
}
|
||||
|
||||
// cacheUserID caches the user ID for a token
|
||||
func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.CacheAuthToken(ctx, token, userID)
|
||||
}
|
||||
|
||||
// InvalidateToken removes a token from the cache
|
||||
func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.InvalidateAuthToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context.
|
||||
// Returns nil if the context value is missing or not of the expected type.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
val := c.Get(AuthUserKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
user, ok := val.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the auth token from the Echo context
|
||||
func GetAuthToken(c echo.Context) string {
|
||||
token := c.Get(AuthTokenKey)
|
||||
if token == nil {
|
||||
return ""
|
||||
}
|
||||
tokenStr, ok := token.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
// MustGetAuthUser retrieves the authenticated user or returns error with 401
|
||||
func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return nil, apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// truncateToken safely truncates a token string for logging.
|
||||
// Returns at most the first 8 characters followed by "...".
|
||||
func truncateToken(token string) string {
|
||||
if len(token) > 8 {
|
||||
return token[:8] + "..."
|
||||
}
|
||||
return token + "..."
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// setupTestDB creates a temporary in-memory SQLite database with the required
|
||||
// tables for auth middleware tests.
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// createTestUserAndToken creates a user and an auth token, then backdates the
|
||||
// token's Created timestamp by the specified number of days.
|
||||
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
|
||||
t.Helper()
|
||||
|
||||
user := &models.User{
|
||||
Username: username,
|
||||
Email: username + "@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token := &models.AuthToken{
|
||||
UserID: user.ID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return user, token
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify user was set in context
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "valid_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "bearer_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Basic "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "JustATokenWithNoScheme")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token ")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
|
||||
|
||||
// Deactivate the user
|
||||
require.NoError(t, db.Model(user).Update("is_active", false).Error)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No Authorization header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "opt_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "opt_user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, 30, m.tokenExpiryDays)
|
||||
|
||||
// Token at 29 days should be valid
|
||||
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
|
||||
// Token at 31 days should be expired with 30-day config
|
||||
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, nil)
|
||||
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
|
||||
}
|
||||
|
||||
func TestGetAuthToken_ReturnsToken(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, "test-token-value")
|
||||
assert.Equal(t, "test-token-value", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No token set
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, 12345) // Wrong type
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
// Legacy tokens without created time should not be expired
|
||||
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
|
||||
}
|
||||
|
||||
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil) // nil cache
|
||||
|
||||
err := m.InvalidateToken(nil, "some-token")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/kratos"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthUserKey stores the authenticated *models.User in the echo context.
|
||||
AuthUserKey = "auth_user"
|
||||
// AuthTokenKey stores the raw session credential in the echo context.
|
||||
AuthTokenKey = "auth_token"
|
||||
// authVerifiedKey stores the Kratos email-verified flag in the context.
|
||||
authVerifiedKey = "auth_email_verified"
|
||||
|
||||
// UserCacheTTL / UserCacheMaxSize bound the in-memory local-user cache.
|
||||
UserCacheTTL = 5 * time.Minute
|
||||
UserCacheMaxSize = 5000
|
||||
|
||||
// kratosSessionCacheTTL is how long a validated session is cached in
|
||||
// Redis, so most authed requests skip the Kratos /whoami round trip.
|
||||
kratosSessionCacheTTL = 5 * time.Minute
|
||||
kratosSessionPrefix = "kratos_sess:"
|
||||
)
|
||||
|
||||
// KratosAuth authenticates requests against an Ory Kratos session. It
|
||||
// replaces the hand-rolled token auth: the session is validated via Kratos
|
||||
// /sessions/whoami (Redis-cached), and the matching local auth_user row is
|
||||
// lazily provisioned on first sight of a Kratos identity.
|
||||
type KratosAuth struct {
|
||||
kratos *kratos.Client
|
||||
cache *services.CacheService
|
||||
db *gorm.DB
|
||||
userCache *UserCache
|
||||
}
|
||||
|
||||
// NewKratosAuth builds the Kratos auth middleware.
|
||||
func NewKratosAuth(k *kratos.Client, cache *services.CacheService, db *gorm.DB) *KratosAuth {
|
||||
return &KratosAuth{
|
||||
kratos: k,
|
||||
cache: cache,
|
||||
db: db,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate validates the Kratos session and requires it.
|
||||
func (m *KratosAuth) Authenticate() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user, verified, cred, err := m.resolve(c)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Kratos authentication failed")
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, cred)
|
||||
c.Set(authVerifiedKey, verified)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuthenticate authenticates if a session is present, else continues
|
||||
// unauthenticated.
|
||||
func (m *KratosAuth) OptionalAuthenticate() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if user, verified, cred, err := m.resolve(c); err == nil {
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, cred)
|
||||
c.Set(authVerifiedKey, verified)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireVerified rejects users whose Kratos email address is not verified.
|
||||
// Apply after Authenticate.
|
||||
func (m *KratosAuth) RequireVerified() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if GetAuthUser(c) == nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
if verified, _ := c.Get(authVerifiedKey).(bool); !verified {
|
||||
return apperrors.Forbidden("error.email_not_verified")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolve validates the request's session and returns the local user.
|
||||
func (m *KratosAuth) resolve(c echo.Context) (*models.User, bool, string, error) {
|
||||
token, cookie := extractSession(c)
|
||||
if token == "" && cookie == "" {
|
||||
return nil, false, "", errors.New("no session credential")
|
||||
}
|
||||
cred := token
|
||||
if cred == "" {
|
||||
cred = cookie
|
||||
}
|
||||
ctx := c.Request().Context()
|
||||
|
||||
// Redis cache: kratos_sess:<hash(cred)> -> "<userID>|<0|1>"
|
||||
cacheKey := kratosSessionPrefix + hashCredential(cred)
|
||||
if m.cache != nil {
|
||||
if v, err := m.cache.GetString(ctx, cacheKey); err == nil && v != "" {
|
||||
if user, verified, ok := m.userFromCacheValue(ctx, v); ok {
|
||||
return user, verified, cred, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sess, err := m.kratos.Whoami(ctx, token, cookie)
|
||||
if err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
user, err := m.provision(ctx, sess)
|
||||
if err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
if m.cache != nil {
|
||||
_ = m.cache.SetString(ctx, cacheKey,
|
||||
fmt.Sprintf("%d|%s", user.ID, boolDigit(sess.EmailVerified())), kratosSessionCacheTTL)
|
||||
}
|
||||
return user, sess.EmailVerified(), cred, nil
|
||||
}
|
||||
|
||||
// provision finds the local auth_user row for a Kratos identity, creating it
|
||||
// (and a UserProfile) on first sight. Concurrent first requests are handled
|
||||
// by re-reading after a unique-constraint conflict.
|
||||
func (m *KratosAuth) provision(ctx context.Context, sess *kratos.Session) (*models.User, error) {
|
||||
var user models.User
|
||||
err := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error
|
||||
if err == nil {
|
||||
return &user, nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user = models.User{
|
||||
KratosID: sess.Identity.ID,
|
||||
Email: sess.Identity.Traits.Email,
|
||||
Username: sess.Identity.Traits.Email,
|
||||
FirstName: sess.Identity.Traits.Name.First,
|
||||
LastName: sess.Identity.Traits.Name.Last,
|
||||
IsActive: true,
|
||||
DateJoined: time.Now().UTC(),
|
||||
}
|
||||
txErr := m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&models.UserProfile{
|
||||
UserID: user.ID,
|
||||
Verified: sess.EmailVerified(),
|
||||
}).Error
|
||||
})
|
||||
if txErr != nil {
|
||||
// Likely a concurrent provision of the same identity — re-read.
|
||||
if e := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error; e == nil {
|
||||
return &user, nil
|
||||
}
|
||||
return nil, txErr
|
||||
}
|
||||
log.Info().Str("kratos_id", sess.Identity.ID).Uint("user_id", user.ID).
|
||||
Msg("provisioned local user from Kratos identity")
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// userFromCacheValue resolves a cached "<userID>|<0|1>" value to a user.
|
||||
func (m *KratosAuth) userFromCacheValue(ctx context.Context, v string) (*models.User, bool, bool) {
|
||||
parts := strings.SplitN(v, "|", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, false, false
|
||||
}
|
||||
var id uint
|
||||
if _, err := fmt.Sscanf(parts[0], "%d", &id); err != nil || id == 0 {
|
||||
return nil, false, false
|
||||
}
|
||||
verified := parts[1] == "1"
|
||||
if cached := m.userCache.Get(id); cached != nil {
|
||||
return cached, verified, true
|
||||
}
|
||||
var user models.User
|
||||
if err := m.db.WithContext(ctx).First(&user, id).Error; err != nil {
|
||||
return nil, false, false
|
||||
}
|
||||
m.userCache.Set(&user)
|
||||
return &user, verified, true
|
||||
}
|
||||
|
||||
// extractSession pulls the session credential from the request: the
|
||||
// X-Session-Token header or Authorization bearer (mobile clients), or the
|
||||
// ory_kratos_session cookie (web).
|
||||
func extractSession(c echo.Context) (token, cookie string) {
|
||||
if t := c.Request().Header.Get("X-Session-Token"); t != "" {
|
||||
token = t
|
||||
} else if ah := c.Request().Header.Get("Authorization"); ah != "" {
|
||||
parts := strings.SplitN(ah, " ", 2)
|
||||
if len(parts) == 2 && (parts[0] == "Bearer" || parts[0] == "Token") {
|
||||
token = parts[1]
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
if ck := c.Request().Header.Get("Cookie"); strings.Contains(ck, "ory_kratos_session") {
|
||||
cookie = ck
|
||||
}
|
||||
}
|
||||
return token, cookie
|
||||
}
|
||||
|
||||
func hashCredential(cred string) string {
|
||||
sum := sha256.Sum256([]byte(cred))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func boolDigit(b bool) string {
|
||||
if b {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
// truncateToken returns the first 8 characters of a credential followed by
|
||||
// "..." for safe inclusion in log lines.
|
||||
func truncateToken(tok string) string {
|
||||
if len(tok) <= 8 {
|
||||
return tok + "..."
|
||||
}
|
||||
return tok[:8] + "..."
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the echo context.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
user, _ := c.Get(AuthUserKey).(*models.User)
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the session credential from the echo context.
|
||||
func GetAuthToken(c echo.Context) string {
|
||||
tok, _ := c.Get(AuthTokenKey).(string)
|
||||
return tok
|
||||
}
|
||||
|
||||
// MustGetAuthUser retrieves the authenticated user or returns a 401 error.
|
||||
func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return nil, apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
Reference in New Issue
Block a user