Initial project setup - Phases 1-3 complete

This commit is contained in:
Trey t
2026-04-06 11:28:40 -05:00
commit c77e506db5
293 changed files with 14233 additions and 0 deletions

View File

@@ -0,0 +1,290 @@
import Foundation
import X509
import Crypto
import SwiftASN1
import NIOSSL
#if canImport(UIKit)
import UIKit
#endif
/// Manages root CA generation, leaf certificate signing, and an LRU certificate cache.
public final class CertificateManager: @unchecked Sendable {
public static let shared = CertificateManager()
private let lock = NSLock()
private var rootCAKey: P256.Signing.PrivateKey?
private var rootCACert: Certificate?
private var rootCANIOSSL: NIOSSLCertificate?
// LRU cache for generated leaf certificates
private var certCache: [String: (NIOSSLCertificate, NIOSSLPrivateKey)] = [:]
private var cacheOrder: [String] = []
private let keychainCAKeyTag = "com.treyt.proxyapp.ca.privatekey"
private let keychainCACertTag = "com.treyt.proxyapp.ca.cert"
private init() {
loadOrGenerateCA()
}
// MARK: - Public API
public var hasCA: Bool {
lock.lock()
defer { lock.unlock() }
return rootCACert != nil
}
/// Get or generate a leaf certificate + key for the given domain.
public func tlsServerContext(for domain: String) throws -> NIOSSLContext {
lock.lock()
defer { lock.unlock() }
if let cached = certCache[domain] {
cacheOrder.removeAll { $0 == domain }
cacheOrder.append(domain)
return try makeServerContext(cert: cached.0, key: cached.1)
}
guard let caKey = rootCAKey, let caCert = rootCACert else {
throw CertificateError.caNotFound
}
let (leafCert, leafKey) = try generateLeaf(domain: domain, caKey: caKey, caCert: caCert)
// Serialize to DER/PEM for NIOSSL
var serializer = DER.Serializer()
try leafCert.serialize(into: &serializer)
let leafDER = serializer.serializedBytes
let nioLeafCert = try NIOSSLCertificate(bytes: leafDER, format: .der)
let leafKeyPEM = leafKey.pemRepresentation
let nioLeafKey = try NIOSSLPrivateKey(bytes: [UInt8](leafKeyPEM.utf8), format: .pem)
certCache[domain] = (nioLeafCert, nioLeafKey)
cacheOrder.append(domain)
while cacheOrder.count > ProxyConstants.certificateCacheSize {
let evicted = cacheOrder.removeFirst()
certCache.removeValue(forKey: evicted)
}
return try makeServerContext(cert: nioLeafCert, key: nioLeafKey)
}
/// Export the root CA as DER data for user installation.
public func exportCACertificateDER() -> [UInt8]? {
lock.lock()
defer { lock.unlock() }
guard let cert = rootCACert else { return nil }
var serializer = DER.Serializer()
try? cert.serialize(into: &serializer)
return serializer.serializedBytes
}
/// Export as PEM for display.
public func exportCACertificatePEM() -> String? {
lock.lock()
defer { lock.unlock() }
guard let cert = rootCACert else { return nil }
guard let pem = try? cert.serializeAsPEM() else { return nil }
return pem.pemString
}
public var caNotValidAfter: Date? {
lock.lock()
defer { lock.unlock() }
// notValidAfter is a Time, not directly a Date we stored the date when generating
return nil // Will be set properly after we store dates
}
// MARK: - CA Generation
private func loadOrGenerateCA() {
if loadCAFromKeychain() { return }
do {
let key = P256.Signing.PrivateKey()
let name = try DistinguishedName {
CommonName("Proxy CA (\(deviceName()))")
OrganizationName("ProxyApp")
}
let now = Date()
let twoYearsLater = now.addingTimeInterval(365 * 24 * 3600 * 2)
let extensions = try Certificate.Extensions {
Critical(BasicConstraints.isCertificateAuthority(maxPathLength: 0))
Critical(KeyUsage(keyCertSign: true, cRLSign: true))
}
let cert = try Certificate(
version: .v3,
serialNumber: Certificate.SerialNumber(),
publicKey: .init(key.publicKey),
notValidBefore: now,
notValidAfter: twoYearsLater,
issuer: name,
subject: name,
signatureAlgorithm: .ecdsaWithSHA256,
extensions: extensions,
issuerPrivateKey: .init(key)
)
self.rootCAKey = key
self.rootCACert = cert
var serializer = DER.Serializer()
try cert.serialize(into: &serializer)
let der = serializer.serializedBytes
self.rootCANIOSSL = try NIOSSLCertificate(bytes: der, format: .der)
saveCAToKeychain(key: key, certDER: der)
print("[CertificateManager] Generated new root CA")
} catch {
print("[CertificateManager] Failed to generate CA: \(error)")
}
}
// MARK: - Leaf Certificate Generation
private func generateLeaf(
domain: String,
caKey: P256.Signing.PrivateKey,
caCert: Certificate
) throws -> (Certificate, P256.Signing.PrivateKey) {
let leafKey = P256.Signing.PrivateKey()
let now = Date()
let oneYearLater = now.addingTimeInterval(365 * 24 * 3600)
let extensions = try Certificate.Extensions {
Critical(BasicConstraints.notCertificateAuthority)
Critical(KeyUsage(digitalSignature: true))
try ExtendedKeyUsage([.serverAuth])
SubjectAlternativeNames([.dnsName(domain)])
}
let leafName = try DistinguishedName {
CommonName(domain)
OrganizationName("ProxyApp")
}
let cert = try Certificate(
version: .v3,
serialNumber: Certificate.SerialNumber(),
publicKey: .init(leafKey.publicKey),
notValidBefore: now,
notValidAfter: oneYearLater,
issuer: caCert.subject,
subject: leafName,
signatureAlgorithm: .ecdsaWithSHA256,
extensions: extensions,
issuerPrivateKey: .init(caKey)
)
return (cert, leafKey)
}
// MARK: - TLS Context
private func makeServerContext(cert: NIOSSLCertificate, key: NIOSSLPrivateKey) throws -> NIOSSLContext {
var certs = [cert]
if let caCert = rootCANIOSSL {
certs.append(caCert)
}
var config = TLSConfiguration.makeServerConfiguration(
certificateChain: certs.map { .certificate($0) },
privateKey: .privateKey(key)
)
config.applicationProtocols = ["http/1.1"]
return try NIOSSLContext(configuration: config)
}
// MARK: - Keychain
private func loadCAFromKeychain() -> Bool {
let keyQuery: [String: Any] = [
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: keychainCAKeyTag,
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
kSecReturnData as String: true
]
var keyResult: AnyObject?
guard SecItemCopyMatching(keyQuery as CFDictionary, &keyResult) == errSecSuccess,
let keyData = keyResult as? Data else { return false }
let certQuery: [String: Any] = [
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: keychainCACertTag,
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
kSecReturnData as String: true
]
var certResult: AnyObject?
guard SecItemCopyMatching(certQuery as CFDictionary, &certResult) == errSecSuccess,
let certData = certResult as? Data else { return false }
do {
let key = try P256.Signing.PrivateKey(rawRepresentation: keyData)
let cert = try Certificate(derEncoded: [UInt8](certData))
let nioCert = try NIOSSLCertificate(bytes: [UInt8](certData), format: .der)
self.rootCAKey = key
self.rootCACert = cert
self.rootCANIOSSL = nioCert
print("[CertificateManager] Loaded CA from Keychain")
return true
} catch {
print("[CertificateManager] Failed to load CA from Keychain: \(error)")
return false
}
}
private func saveCAToKeychain(key: P256.Signing.PrivateKey, certDER: [UInt8]) {
let keyData = key.rawRepresentation
// Delete existing entries
for tag in [keychainCAKeyTag, keychainCACertTag] {
let deleteQuery: [String: Any] = [
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: tag,
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier
]
SecItemDelete(deleteQuery as CFDictionary)
}
// Save key
let addKeyQuery: [String: Any] = [
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: keychainCAKeyTag,
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
kSecValueData as String: keyData,
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
]
SecItemAdd(addKeyQuery as CFDictionary, nil)
// Save cert
let addCertQuery: [String: Any] = [
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: keychainCACertTag,
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
kSecValueData as String: Data(certDER),
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
]
SecItemAdd(addCertQuery as CFDictionary, nil)
}
// MARK: - Helpers
private func deviceName() -> String {
#if canImport(UIKit)
return UIDevice.current.name
#else
return Host.current().localizedName ?? "Unknown"
#endif
}
public enum CertificateError: Error {
case notImplemented
case generationFailed
case caNotFound
}
}

View File

@@ -0,0 +1,298 @@
import Foundation
import NIOCore
import NIOPosix
import NIOHTTP1
/// Handles incoming proxy requests:
/// - HTTP CONNECT establishes TCP tunnel (GlueHandler passthrough, or MITM in Phase 3)
/// - Plain HTTP connects upstream, forwards request, captures request+response
final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart
private let trafficRepo: TrafficRepository
// Buffer request parts until we've connected upstream
private var pendingHead: HTTPRequestHead?
private var pendingBody: [ByteBuffer] = []
private var pendingEnd: HTTPHeaders?
private var receivedEnd = false
init(trafficRepo: TrafficRepository) {
self.trafficRepo = trafficRepo
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = unwrapInboundIn(data)
switch part {
case .head(let head):
if head.method == .CONNECT {
handleConnect(context: context, head: head)
} else {
pendingHead = head
}
case .body(let buffer):
pendingBody.append(buffer)
case .end(let trailers):
if pendingHead != nil {
pendingEnd = trailers
receivedEnd = true
handleHTTPRequest(context: context)
}
}
}
// MARK: - CONNECT (HTTPS tunnel)
private func handleConnect(context: ChannelHandlerContext, head: HTTPRequestHead) {
let components = head.uri.split(separator: ":")
let host = String(components[0])
let port = components.count > 1 ? Int(components[1]) ?? 443 : 443
// Check if this domain should be MITM'd (SSL Proxying enabled + domain in include list)
let shouldMITM = shouldInterceptSSL(domain: host)
// Send 200 Connection Established
let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
if shouldMITM {
// MITM mode: strip HTTP handlers, install MITMHandler
setupMITM(context: context, host: host, port: port)
} else {
// Passthrough mode: record domain-level entry, tunnel raw bytes
recordConnectTraffic(host: host, port: port)
// We don't need to connect upstream ourselves GlueHandler does raw forwarding
// But GlueHandler pairs two channels, so we need the remote channel first
ClientBootstrap(group: context.eventLoop)
.channelOption(.socketOption(.so_reuseaddr), value: 1)
.connect(host: host, port: port)
.whenComplete { result in
switch result {
case .success(let remoteChannel):
self.setupGlue(context: context, remoteChannel: remoteChannel)
case .failure(let error):
print("[Proxy] CONNECT passthrough failed to \(host):\(port): \(error)")
context.close(promise: nil)
}
}
}
}
private func shouldInterceptSSL(domain: String) -> Bool {
guard IPCManager.shared.isSSLProxyingEnabled else { return false }
guard CertificateManager.shared.hasCA else { return false }
// Check SSL proxying list from database
let rulesRepo = RulesRepository()
do {
let entries = try rulesRepo.fetchAllSSLEntries()
// Check exclude list first
for entry in entries where !entry.isInclude {
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
return false
}
}
// Check include list
for entry in entries where entry.isInclude {
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
return true
}
}
} catch {
print("[Proxy] Failed to check SSL proxying list: \(error)")
}
return false
}
private func setupMITM(context: ChannelHandlerContext, host: String, port: Int) {
let mitmHandler = MITMHandler(host: host, port: port, trafficRepo: trafficRepo)
// Remove HTTP handlers, keep raw bytes for MITMHandler
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
.whenSuccess { handler in
context.channel.pipeline.removeHandler(handler, promise: nil)
}
context.pipeline.removeHandler(context: context).whenComplete { _ in
context.channel.pipeline.addHandler(mitmHandler).whenFailure { error in
print("[Proxy] Failed to install MITM handler: \(error)")
context.close(promise: nil)
}
}
}
private func setupGlue(context: ChannelHandlerContext, remoteChannel: Channel) {
let localGlue = GlueHandler()
let remoteGlue = GlueHandler()
localGlue.partner = remoteGlue
remoteGlue.partner = localGlue
// Remove all HTTP handlers from the client channel, leaving raw bytes
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
.whenSuccess { handler in
context.channel.pipeline.removeHandler(handler, promise: nil)
}
context.pipeline.removeHandler(context: context).whenComplete { _ in
context.channel.pipeline.addHandler(localGlue).whenSuccess {
remoteChannel.pipeline.addHandler(remoteGlue).whenFailure { _ in
context.close(promise: nil)
remoteChannel.close(promise: nil)
}
}
}
}
// MARK: - Plain HTTP forwarding
private func handleHTTPRequest(context: ChannelHandlerContext) {
guard let head = pendingHead else { return }
// Parse host and port from the absolute URI or Host header
guard let (host, port, path) = parseHTTPTarget(head: head) else {
let responseHead = HTTPResponseHead(version: .http1_1, status: .badRequest)
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
return
}
// Rewrite the request URI to relative path (upstream expects /path, not http://host/path)
var upstreamHead = head
upstreamHead.uri = path
// Ensure Host header is set
if !upstreamHead.headers.contains(name: "Host") {
upstreamHead.headers.add(name: "Host", value: host)
}
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: host, scheme: "http")
ClientBootstrap(group: context.eventLoop)
.channelOption(.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
// Remote channel: decode HTTP responses, encode HTTP requests
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
}.flatMap {
channel.pipeline.addHandler(captureHandler)
}.flatMap {
channel.pipeline.addHandler(
HTTPRelayHandler(clientContext: context, wrapResponse: self.wrapOutboundOut)
)
}
}
.connect(host: host, port: port)
.whenComplete { result in
switch result {
case .success(let remoteChannel):
// Forward the buffered request to upstream
remoteChannel.write(NIOAny(HTTPClientRequestPart.head(upstreamHead)), promise: nil)
for bodyBuffer in self.pendingBody {
remoteChannel.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyBuffer))), promise: nil)
}
remoteChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(self.pendingEnd)), promise: nil)
// Clear buffered data
self.pendingHead = nil
self.pendingBody.removeAll()
self.pendingEnd = nil
case .failure(let error):
print("[Proxy] HTTP forward failed to \(host):\(port): \(error)")
let responseHead = HTTPResponseHead(version: .http1_1, status: .badGateway)
context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}
}
}
// MARK: - URL Parsing
private func parseHTTPTarget(head: HTTPRequestHead) -> (host: String, port: Int, path: String)? {
// Absolute URI: "http://example.com:8080/path?query"
if head.uri.hasPrefix("http://") || head.uri.hasPrefix("https://") {
guard let url = URLComponents(string: head.uri) else { return nil }
let host = url.host ?? ""
let port = url.port ?? (head.uri.hasPrefix("https") ? 443 : 80)
var path = url.path.isEmpty ? "/" : url.path
if let query = url.query {
path += "?\(query)"
}
return (host, port, path)
}
// Relative URI with Host header
if let hostHeader = head.headers.first(name: "Host") {
let parts = hostHeader.split(separator: ":")
let host = String(parts[0])
let port = parts.count > 1 ? Int(parts[1]) ?? 80 : 80
return (host, port, head.uri)
}
return nil
}
// MARK: - CONNECT traffic recording
private func recordConnectTraffic(host: String, port: Int) {
var traffic = CapturedTraffic(
domain: host,
url: "https://\(host):\(port)",
method: "CONNECT",
scheme: "https",
statusCode: 200,
statusText: "Connection Established",
startedAt: Date().timeIntervalSince1970,
completedAt: Date().timeIntervalSince1970,
durationMs: 0,
isSslDecrypted: false
)
try? trafficRepo.insert(&traffic)
IPCManager.shared.post(.newTrafficCaptured)
}
}
// MARK: - HTTPRelayHandler
/// Relays HTTP responses from the upstream server back to the proxy client.
final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
private let clientContext: ChannelHandlerContext
private let wrapResponse: (HTTPServerResponsePart) -> NIOAny
init(clientContext: ChannelHandlerContext, wrapResponse: @escaping (HTTPServerResponsePart) -> NIOAny) {
self.clientContext = clientContext
self.wrapResponse = wrapResponse
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = unwrapInboundIn(data)
switch part {
case .head(let head):
let serverHead = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
clientContext.write(wrapResponse(.head(serverHead)), promise: nil)
case .body(let buffer):
clientContext.write(wrapResponse(.body(.byteBuffer(buffer))), promise: nil)
case .end(let trailers):
clientContext.writeAndFlush(wrapResponse(.end(trailers)), promise: nil)
}
}
func channelInactive(context: ChannelHandlerContext) {
clientContext.close(promise: nil)
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
print("[Proxy] Relay error: \(error)")
context.close(promise: nil)
clientContext.close(promise: nil)
}
}

View File

@@ -0,0 +1,66 @@
import Foundation
import NIOCore
/// Bidirectional TCP forwarder. Pairs two channels so bytes flow in both directions.
/// Used for CONNECT tunneling (passthrough mode, no MITM).
final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
typealias OutboundOut = ByteBuffer
var partner: GlueHandler?
private var context: ChannelHandlerContext?
private var pendingRead = false
func handlerAdded(context: ChannelHandlerContext) {
self.context = context
}
func handlerRemoved(context: ChannelHandlerContext) {
self.context = nil
self.partner = nil
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
partner?.write(unwrapInboundIn(data))
}
func channelReadComplete(context: ChannelHandlerContext) {
partner?.flush()
}
func channelInactive(context: ChannelHandlerContext) {
partner?.close()
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
context.close(promise: nil)
}
func channelWritabilityChanged(context: ChannelHandlerContext) {
if context.channel.isWritable {
partner?.read()
}
}
// MARK: - Partner operations
private func write(_ buffer: ByteBuffer) {
context?.write(wrapOutboundOut(buffer), promise: nil)
}
private func flush() {
context?.flush()
}
private func read() {
if let context, !pendingRead {
pendingRead = true
context.read()
pendingRead = false
}
}
private func close() {
context?.close(promise: nil)
}
}

View File

@@ -0,0 +1,141 @@
import Foundation
import NIOCore
import NIOHTTP1
/// Captures HTTP request/response pairs and writes them to the traffic database.
/// Inserted into the pipeline after TLS termination (MITM) or for plain HTTP.
final class HTTPCaptureHandler: ChannelDuplexHandler {
typealias InboundIn = HTTPClientResponsePart
typealias InboundOut = HTTPClientResponsePart
typealias OutboundIn = HTTPClientRequestPart
typealias OutboundOut = HTTPClientRequestPart
private let trafficRepo: TrafficRepository
private let domain: String
private let scheme: String
private var currentRequestId: String?
private var requestHead: HTTPRequestHead?
private var requestBody = Data()
private var responseHead: HTTPResponseHead?
private var responseBody = Data()
private var requestStartTime: Double = 0
init(trafficRepo: TrafficRepository, domain: String, scheme: String = "https") {
self.trafficRepo = trafficRepo
self.domain = domain
self.scheme = scheme
}
// MARK: - Outbound (Request)
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let part = unwrapOutboundIn(data)
switch part {
case .head(let head):
currentRequestId = UUID().uuidString
requestHead = head
requestBody = Data()
requestStartTime = Date().timeIntervalSince1970
case .body(.byteBuffer(let buffer)):
if requestBody.count < ProxyConstants.maxBodySizeBytes {
requestBody.append(contentsOf: buffer.readableBytesView)
}
case .end:
saveRequest()
default:
break
}
context.write(data, promise: promise)
}
// MARK: - Inbound (Response)
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = unwrapInboundIn(data)
switch part {
case .head(let head):
responseHead = head
responseBody = Data()
case .body(let buffer):
if responseBody.count < ProxyConstants.maxBodySizeBytes {
responseBody.append(contentsOf: buffer.readableBytesView)
}
case .end:
saveResponse()
}
context.fireChannelRead(data)
}
// MARK: - Persistence
private func saveRequest() {
guard let head = requestHead, let reqId = currentRequestId else { return }
let url = "\(scheme)://\(domain)\(head.uri)"
let headersJSON = encodeHeaders(head.headers)
let queryParams = extractQueryParams(from: head.uri)
var traffic = CapturedTraffic(
requestId: reqId,
domain: domain,
url: url,
method: head.method.rawValue,
scheme: scheme,
requestHeaders: headersJSON,
requestBody: requestBody.isEmpty ? nil : requestBody,
requestBodySize: requestBody.count,
requestContentType: head.headers.first(name: "Content-Type"),
queryParameters: queryParams,
startedAt: requestStartTime,
isSslDecrypted: scheme == "https"
)
try? trafficRepo.insert(&traffic)
}
private func saveResponse() {
guard let reqId = currentRequestId, let head = responseHead else { return }
let now = Date().timeIntervalSince1970
let durationMs = Int((now - requestStartTime) * 1000)
try? trafficRepo.updateResponse(
requestId: reqId,
statusCode: Int(head.status.code),
statusText: head.status.reasonPhrase,
responseHeaders: encodeHeaders(head.headers),
responseBody: responseBody.isEmpty ? nil : responseBody,
responseBodySize: responseBody.count,
responseContentType: head.headers.first(name: "Content-Type"),
completedAt: now,
durationMs: durationMs
)
IPCManager.shared.post(.newTrafficCaptured)
}
private func encodeHeaders(_ headers: HTTPHeaders) -> String? {
var dict: [String: String] = [:]
for (name, value) in headers {
dict[name] = value
}
guard let data = try? JSONEncoder().encode(dict) else { return nil }
return String(data: data, encoding: .utf8)
}
private func extractQueryParams(from uri: String) -> String? {
guard let url = URLComponents(string: uri),
let items = url.queryItems, !items.isEmpty else { return nil }
var dict: [String: String] = [:]
for item in items {
dict[item.name] = item.value ?? ""
}
guard let data = try? JSONEncoder().encode(dict) else { return nil }
return String(data: data, encoding: .utf8)
}
}

View File

@@ -0,0 +1,294 @@
import Foundation
import NIOCore
import NIOPosix
import NIOSSL
import NIOHTTP1
/// After a CONNECT tunnel is established, this handler:
/// 1. Reads the first bytes from the client to extract the SNI hostname from the TLS ClientHello
/// 2. Generates a per-domain leaf certificate via CertificateManager
/// 3. Terminates client-side TLS with the generated cert
/// 4. Initiates server-side TLS to the real server
/// 5. Installs HTTP codecs + HTTPCaptureHandler on both sides to capture decrypted traffic
final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
private let host: String
private let port: Int
private let trafficRepo: TrafficRepository
private let certManager: CertificateManager
init(host: String, port: Int, trafficRepo: TrafficRepository, certManager: CertificateManager = .shared) {
self.host = host
self.port = port
self.trafficRepo = trafficRepo
self.certManager = certManager
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
var buffer = unwrapInboundIn(data)
// Extract SNI from ClientHello if possible, otherwise use the CONNECT host
let sniDomain = extractSNI(from: buffer) ?? host
// Remove this handler we'll rebuild the pipeline
context.pipeline.removeHandler(self, promise: nil)
// Get TLS context for this domain
let sslContext: NIOSSLContext
do {
sslContext = try certManager.tlsServerContext(for: sniDomain)
} catch {
print("[MITM] Failed to get TLS context for \(sniDomain): \(error)")
context.close(promise: nil)
return
}
// Add server-side TLS handler (we are the "server" to the client)
let sslServerHandler = NIOSSLServerHandler(context: sslContext)
let trafficRepo = self.trafficRepo
let host = self.host
let port = self.port
context.channel.pipeline.addHandler(sslServerHandler, position: .first).flatMap {
// Add HTTP codec after TLS
context.channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))
}.flatMap {
context.channel.pipeline.addHandler(HTTPResponseEncoder())
}.flatMap {
// Add the forwarding handler that connects to the real server
context.channel.pipeline.addHandler(
MITMForwardHandler(
remoteHost: host,
remotePort: port,
domain: sniDomain,
trafficRepo: trafficRepo
)
)
}.whenComplete { result in
switch result {
case .success:
// Re-fire the original ClientHello bytes so TLS handshake proceeds
context.channel.pipeline.fireChannelRead(NIOAny(buffer))
case .failure(let error):
print("[MITM] Pipeline setup failed: \(error)")
context.close(promise: nil)
}
}
}
// MARK: - SNI Extraction
/// Parse the SNI hostname from a TLS ClientHello message.
private func extractSNI(from buffer: ByteBuffer) -> String? {
var buf = buffer
guard buf.readableBytes >= 43 else { return nil }
// TLS record header
guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil } // Handshake
let _ = buf.readInteger(as: UInt16.self) // Version
let _ = buf.readInteger(as: UInt16.self) // Length
// Handshake header
guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil } // ClientHello
let _ = buf.readBytes(length: 3) // Length (3 bytes)
// Client version
let _ = buf.readInteger(as: UInt16.self)
// Random (32 bytes)
guard buf.readBytes(length: 32) != nil else { return nil }
// Session ID
guard let sessionIdLen = buf.readInteger(as: UInt8.self) else { return nil }
guard buf.readBytes(length: Int(sessionIdLen)) != nil else { return nil }
// Cipher suites
guard let cipherSuitesLen = buf.readInteger(as: UInt16.self) else { return nil }
guard buf.readBytes(length: Int(cipherSuitesLen)) != nil else { return nil }
// Compression methods
guard let compMethodsLen = buf.readInteger(as: UInt8.self) else { return nil }
guard buf.readBytes(length: Int(compMethodsLen)) != nil else { return nil }
// Extensions
guard let extensionsLen = buf.readInteger(as: UInt16.self) else { return nil }
var extensionsRemaining = Int(extensionsLen)
while extensionsRemaining > 4 {
guard let extType = buf.readInteger(as: UInt16.self),
let extLen = buf.readInteger(as: UInt16.self) else { return nil }
extensionsRemaining -= 4 + Int(extLen)
if extType == 0x0000 { // SNI extension
guard let _ = buf.readInteger(as: UInt16.self), // SNI list length
let nameType = buf.readInteger(as: UInt8.self),
nameType == 0x00, // hostname
let nameLen = buf.readInteger(as: UInt16.self),
let nameBytes = buf.readBytes(length: Int(nameLen)) else {
return nil
}
return String(bytes: nameBytes, encoding: .utf8)
} else {
guard buf.readBytes(length: Int(extLen)) != nil else { return nil }
}
}
return nil
}
}
// MARK: - MITMForwardHandler
/// Handles decrypted HTTP from the client, forwards to the real server over TLS,
/// and relays responses back. Captures everything via HTTPCaptureHandler.
final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart
private let remoteHost: String
private let remotePort: Int
private let domain: String
private let trafficRepo: TrafficRepository
private var remoteChannel: Channel?
// Buffer request parts until upstream is connected
private var pendingParts: [HTTPServerRequestPart] = []
private var isConnected = false
init(remoteHost: String, remotePort: Int, domain: String, trafficRepo: TrafficRepository) {
self.remoteHost = remoteHost
self.remotePort = remotePort
self.domain = domain
self.trafficRepo = trafficRepo
}
func handlerAdded(context: ChannelHandlerContext) {
connectToRemote(context: context)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = unwrapInboundIn(data)
if isConnected, let remote = remoteChannel {
// Forward to upstream as client request
switch part {
case .head(let head):
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
if !clientHead.headers.contains(name: "Host") {
clientHead.headers.add(name: "Host", value: domain)
}
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
case .body(let buffer):
remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
case .end(let trailers):
remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
}
} else {
pendingParts.append(part)
}
}
func channelInactive(context: ChannelHandlerContext) {
remoteChannel?.close(promise: nil)
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
print("[MITMForward] Error: \(error)")
context.close(promise: nil)
remoteChannel?.close(promise: nil)
}
private func connectToRemote(context: ChannelHandlerContext) {
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: domain, scheme: "https")
let clientContext = context
do {
let tlsConfig = TLSConfiguration.makeClientConfiguration()
let sslContext = try NIOSSLContext(configuration: tlsConfig)
ClientBootstrap(group: context.eventLoop)
.channelOption(.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
let sniHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.domain)
return channel.pipeline.addHandler(sniHandler).flatMap {
channel.pipeline.addHandler(HTTPRequestEncoder())
}.flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))
}.flatMap {
channel.pipeline.addHandler(captureHandler)
}.flatMap {
channel.pipeline.addHandler(
MITMRelayHandler(clientContext: clientContext)
)
}
}
.connect(host: remoteHost, port: remotePort)
.whenComplete { result in
switch result {
case .success(let channel):
self.remoteChannel = channel
self.isConnected = true
self.flushPending(remote: channel)
case .failure(let error):
print("[MITMForward] Connect to \(self.remoteHost):\(self.remotePort) failed: \(error)")
clientContext.close(promise: nil)
}
}
} catch {
print("[MITMForward] TLS setup failed: \(error)")
context.close(promise: nil)
}
}
private func flushPending(remote: Channel) {
for part in pendingParts {
switch part {
case .head(let head):
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
if !clientHead.headers.contains(name: "Host") {
clientHead.headers.add(name: "Host", value: domain)
}
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
case .body(let buffer):
remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
case .end(let trailers):
remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
}
}
pendingParts.removeAll()
}
}
// MARK: - MITMRelayHandler
/// Relays responses from the real server back to the proxy client.
final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
private let clientContext: ChannelHandlerContext
init(clientContext: ChannelHandlerContext) {
self.clientContext = clientContext
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = unwrapInboundIn(data)
switch part {
case .head(let head):
let serverResponse = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
clientContext.write(NIOAny(HTTPServerResponsePart.head(serverResponse)), promise: nil)
case .body(let buffer):
clientContext.write(NIOAny(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil)
case .end(let trailers):
clientContext.writeAndFlush(NIOAny(HTTPServerResponsePart.end(trailers)), promise: nil)
}
}
func channelInactive(context: ChannelHandlerContext) {
clientContext.close(promise: nil)
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
print("[MITMRelay] Error: \(error)")
context.close(promise: nil)
clientContext.close(promise: nil)
}
}

View File

@@ -0,0 +1,55 @@
import Foundation
import NIOCore
import NIOPosix
import NIOHTTP1
public final class ProxyServer: Sendable {
private let host: String
private let port: Int
private let group: EventLoopGroup
private let trafficRepo: TrafficRepository
nonisolated(unsafe) private var channel: Channel?
public init(
host: String = ProxyConstants.proxyHost,
port: Int = ProxyConstants.proxyPort,
trafficRepo: TrafficRepository = TrafficRepository()
) {
self.host = host
self.port = port
// Use only 1 thread to conserve memory in the extension (50MB budget)
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.trafficRepo = trafficRepo
}
public func start() async throws {
let trafficRepo = self.trafficRepo
let bootstrap = ServerBootstrap(group: group)
.serverChannelOption(.backlog, value: 256)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelInitializer { channel in
channel.pipeline.addHandler(
ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
).flatMap {
channel.pipeline.addHandler(HTTPResponseEncoder())
}.flatMap {
channel.pipeline.addHandler(ConnectHandler(trafficRepo: trafficRepo))
}
}
.childChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.maxMessagesPerRead, value: 16)
channel = try await bootstrap.bind(host: host, port: port).get()
print("[ProxyServer] Listening on \(host):\(port)")
}
public func stop() async {
do {
try await channel?.close()
try await group.shutdownGracefully()
} catch {
print("[ProxyServer] Shutdown error: \(error)")
}
}
}