295 lines
12 KiB
Swift
295 lines
12 KiB
Swift
import Foundation
|
|
import NIOCore
|
|
import NIOPosix
|
|
import NIOSSL
|
|
import NIOHTTP1
|
|
|
|
/// After a CONNECT tunnel is established, this handler:
/// 1. Reads the first bytes from the client to extract the SNI hostname from the TLS ClientHello
/// 2. Generates a per-domain leaf certificate via CertificateManager
/// 3. Terminates client-side TLS with the generated cert
/// 4. Initiates server-side TLS to the real server
/// 5. Installs HTTP codecs + HTTPCaptureHandler on both sides to capture decrypted traffic
///
/// Steps 4 and 5 are delegated to `MITMForwardHandler`, which this handler installs before
/// removing itself from the pipeline after the first read.
final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = ByteBuffer

    private let host: String
    private let port: Int
    private let trafficRepo: TrafficRepository
    private let certManager: CertificateManager

    /// - Parameters:
    ///   - host: CONNECT target host; used for the upstream connection and as the SNI fallback.
    ///   - port: CONNECT target port.
    ///   - trafficRepo: Sink handed to the forwarding/capture handlers.
    ///   - certManager: Source of per-domain server TLS contexts.
    init(host: String, port: Int, trafficRepo: TrafficRepository, certManager: CertificateManager = .shared) {
        self.host = host
        self.port = port
        self.trafficRepo = trafficRepo
        self.certManager = certManager
    }

    /// Handles the first chunk of client bytes (expected to be a TLS ClientHello).
    ///
    /// The handler removes itself, installs TLS termination + HTTP codecs + `MITMForwardHandler`,
    /// then re-fires the original bytes from the pipeline head so the TLS handshake proceeds.
    /// On any setup failure the client channel is closed.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let buffer = unwrapInboundIn(data)

        // Capture the channel before leaving the pipeline: once this handler is removed its
        // context is detached, so `context.close` would no longer reach the channel and a
        // failed setup would leave the client connection open.
        let channel = context.channel
        let pipeline = channel.pipeline

        // Extract SNI from ClientHello if possible, otherwise use the CONNECT host
        let sniDomain = extractSNI(from: buffer) ?? host

        // Remove this handler — we'll rebuild the pipeline
        pipeline.removeHandler(self, promise: nil)

        // Get TLS context for this domain
        let sslContext: NIOSSLContext
        do {
            sslContext = try certManager.tlsServerContext(for: sniDomain)
        } catch {
            print("[MITM] Failed to get TLS context for \(sniDomain): \(error)")
            channel.close(promise: nil)
            return
        }

        // Copy stored properties so the future callbacks don't need `self`.
        let trafficRepo = self.trafficRepo
        let host = self.host
        let port = self.port

        // Add server-side TLS handler first (we are the "server" to the client), then HTTP
        // codecs after TLS, then the handler that connects to the real server.
        pipeline.addHandler(NIOSSLServerHandler(context: sslContext), position: .first).flatMap {
            pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))
        }.flatMap {
            pipeline.addHandler(HTTPResponseEncoder())
        }.flatMap {
            pipeline.addHandler(
                MITMForwardHandler(
                    remoteHost: host,
                    remotePort: port,
                    domain: sniDomain,
                    trafficRepo: trafficRepo
                )
            )
        }.whenComplete { result in
            switch result {
            case .success:
                // Re-fire the original ClientHello bytes so TLS handshake proceeds
                pipeline.fireChannelRead(NIOAny(buffer))
            case .failure(let error):
                print("[MITM] Pipeline setup failed: \(error)")
                channel.close(promise: nil)
            }
        }
    }

    // MARK: - SNI Extraction

    /// Parses the SNI hostname from a TLS ClientHello message.
    ///
    /// Only the bytes present in `buffer` are inspected (the buffer is not consumed). A
    /// non-handshake record, a ClientHello truncated within this buffer, a non-host_name
    /// entry, or an empty name all yield `nil`, so the caller falls back to the CONNECT host.
    ///
    /// - Parameter buffer: The first bytes received from the client.
    /// - Returns: The first host_name entry of the server_name extension, or `nil`.
    private func extractSNI(from buffer: ByteBuffer) -> String? {
        var buf = buffer
        guard buf.readableBytes >= 43 else { return nil }

        // TLS record header
        guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil } // Handshake
        _ = buf.readInteger(as: UInt16.self) // Version
        _ = buf.readInteger(as: UInt16.self) // Length

        // Handshake header
        guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil } // ClientHello
        _ = buf.readBytes(length: 3) // Length (3 bytes)

        // Client version
        _ = buf.readInteger(as: UInt16.self)
        // Random (32 bytes)
        guard buf.readBytes(length: 32) != nil else { return nil }
        // Session ID
        guard let sessionIdLen = buf.readInteger(as: UInt8.self),
              buf.readBytes(length: Int(sessionIdLen)) != nil else { return nil }
        // Cipher suites
        guard let cipherSuitesLen = buf.readInteger(as: UInt16.self),
              buf.readBytes(length: Int(cipherSuitesLen)) != nil else { return nil }
        // Compression methods
        guard let compMethodsLen = buf.readInteger(as: UInt8.self),
              buf.readBytes(length: Int(compMethodsLen)) != nil else { return nil }

        // Extensions
        guard let extensionsLen = buf.readInteger(as: UInt16.self) else { return nil }
        var extensionsRemaining = Int(extensionsLen)

        // Each extension has a 4-byte header (type + length); a zero-length extension is
        // exactly 4 bytes, so the loop must still run when exactly 4 bytes remain.
        while extensionsRemaining >= 4 {
            guard let extType = buf.readInteger(as: UInt16.self),
                  let extLen = buf.readInteger(as: UInt16.self) else { return nil }
            extensionsRemaining -= 4 + Int(extLen)

            if extType == 0x0000 { // SNI extension
                guard buf.readInteger(as: UInt16.self) != nil, // SNI list length
                      let nameType = buf.readInteger(as: UInt8.self),
                      nameType == 0x00, // hostname
                      let nameLen = buf.readInteger(as: UInt16.self),
                      nameLen > 0, // an empty name must not override the CONNECT host
                      let nameBytes = buf.readBytes(length: Int(nameLen)) else {
                    return nil
                }
                return String(bytes: nameBytes, encoding: .utf8)
            }

            guard buf.readBytes(length: Int(extLen)) != nil else { return nil }
        }

        return nil
    }
}
|
|
|
|
// MARK: - MITMForwardHandler
|
|
|
|
/// Handles decrypted HTTP from the client, forwards to the real server over TLS,
/// and relays responses back. Captures everything via HTTPCaptureHandler.
///
/// The upstream connection is started as soon as the handler is added. Request parts that
/// arrive before it completes are buffered and flushed in arrival order once it does.
final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPServerRequestPart
    typealias OutboundOut = HTTPServerResponsePart

    private let remoteHost: String
    private let remotePort: Int
    private let domain: String
    private let trafficRepo: TrafficRepository
    private var remoteChannel: Channel?

    // Buffer request parts until upstream is connected
    private var pendingParts: [HTTPServerRequestPart] = []
    private var isConnected = false
    // Set once the client side has gone away, so an upstream connect that completes later is
    // closed immediately instead of being left open with nobody to read its responses.
    private var isClientClosed = false

    /// - Parameters:
    ///   - remoteHost: Host to open the upstream TCP connection to.
    ///   - remotePort: Port to open the upstream TCP connection to.
    ///   - domain: Name used for upstream SNI and for a `Host` header the client omitted.
    ///   - trafficRepo: Sink handed to `HTTPCaptureHandler`.
    init(remoteHost: String, remotePort: Int, domain: String, trafficRepo: TrafficRepository) {
        self.remoteHost = remoteHost
        self.remotePort = remotePort
        self.domain = domain
        self.trafficRepo = trafficRepo
    }

    func handlerAdded(context: ChannelHandlerContext) {
        connectToRemote(context: context)
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let part = unwrapInboundIn(data)

        if isConnected, let remote = remoteChannel {
            forward(part, to: remote)
        } else {
            pendingParts.append(part)
        }
    }

    func channelInactive(context: ChannelHandlerContext) {
        isClientClosed = true
        remoteChannel?.close(promise: nil)
    }

    func errorCaught(context: ChannelHandlerContext, error: Error) {
        print("[MITMForward] Error: \(error)")
        isClientClosed = true
        context.close(promise: nil)
        remoteChannel?.close(promise: nil)
    }

    /// Opens a TLS connection to `remoteHost:remotePort` on the client's event loop, with
    /// HTTP client codecs, `HTTPCaptureHandler`, and a `MITMRelayHandler` that writes
    /// responses back through `context`. Closes the client channel on failure.
    private func connectToRemote(context: ChannelHandlerContext) {
        let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: domain, scheme: "https")
        let clientContext = context
        // NIOSSLClientHandler throws for IP literals and empty names in SNI. `domain` can be
        // an IP when the ClientHello had no SNI and the CONNECT target was an IP address;
        // with no hostname, NIOSSL validates the certificate against the peer address instead.
        let serverHostname = Self.sniHostname(for: domain)

        do {
            let tlsConfig = TLSConfiguration.makeClientConfiguration()
            let sslContext = try NIOSSLContext(configuration: tlsConfig)

            ClientBootstrap(group: context.eventLoop)
                .channelOption(.socketOption(.so_reuseaddr), value: 1)
                .channelInitializer { channel in
                    let tlsHandler: NIOSSLClientHandler
                    do {
                        tlsHandler = try NIOSSLClientHandler(context: sslContext, serverHostname: serverHostname)
                    } catch {
                        // Fail this connection rather than crashing the whole proxy.
                        return channel.eventLoop.makeFailedFuture(error)
                    }
                    return channel.pipeline.addHandler(tlsHandler).flatMap {
                        channel.pipeline.addHandler(HTTPRequestEncoder())
                    }.flatMap {
                        channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))
                    }.flatMap {
                        channel.pipeline.addHandler(captureHandler)
                    }.flatMap {
                        channel.pipeline.addHandler(
                            MITMRelayHandler(clientContext: clientContext)
                        )
                    }
                }
                .connect(host: remoteHost, port: remotePort)
                .whenComplete { result in
                    switch result {
                    case .success(let channel):
                        // The client may have disconnected while we were connecting.
                        guard !self.isClientClosed else {
                            channel.close(promise: nil)
                            return
                        }
                        self.remoteChannel = channel
                        self.isConnected = true
                        self.flushPending(remote: channel)
                    case .failure(let error):
                        print("[MITMForward] Connect to \(self.remoteHost):\(self.remotePort) failed: \(error)")
                        clientContext.close(promise: nil)
                    }
                }
        } catch {
            print("[MITMForward] TLS setup failed: \(error)")
            context.close(promise: nil)
        }
    }

    /// Returns the name to send as upstream SNI, or `nil` when `domain` is empty or an IP literal.
    private static func sniHostname(for domain: String) -> String? {
        guard !domain.isEmpty else { return nil }
        let isIPLiteral = (try? SocketAddress(ipAddress: domain, port: 0)) != nil
        return isIPLiteral ? nil : domain
    }

    /// Writes all buffered request parts upstream in arrival order, then clears the buffer.
    private func flushPending(remote: Channel) {
        for part in pendingParts {
            forward(part, to: remote)
        }
        pendingParts.removeAll()
    }

    /// Re-wraps one server-side request part as a client-side part and writes it upstream.
    ///
    /// Adds a `Host` header set to `domain` when the client sent none. Head and body parts are
    /// written without flushing; `.end` flushes.
    private func forward(_ part: HTTPServerRequestPart, to remote: Channel) {
        switch part {
        case .head(let head):
            var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
            if !clientHead.headers.contains(name: "Host") {
                clientHead.headers.add(name: "Host", value: domain)
            }
            remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
        case .body(let buffer):
            remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
        case .end(let trailers):
            remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
        }
    }
}
|
|
|
|
// MARK: - MITMRelayHandler
|
|
|
|
/// Relays responses from the real server back to the proxy client.
///
/// Every decoded upstream response part is re-wrapped as a server response part and
/// written through `clientContext`. Writes are flushed only when the `.end` part arrives.
/// Upstream closure or errors close the client side as well.
final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPClientResponsePart

    /// Handler context on the client-facing channel that responses are written back through.
    private let clientContext: ChannelHandlerContext

    init(clientContext: ChannelHandlerContext) {
        self.clientContext = clientContext
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let outbound: HTTPServerResponsePart
        let endsResponse: Bool

        switch unwrapInboundIn(data) {
        case .head(let upstreamHead):
            outbound = .head(HTTPResponseHead(version: upstreamHead.version,
                                              status: upstreamHead.status,
                                              headers: upstreamHead.headers))
            endsResponse = false
        case .body(let chunk):
            outbound = .body(.byteBuffer(chunk))
            endsResponse = false
        case .end(let trailers):
            outbound = .end(trailers)
            endsResponse = true
        }

        let wrapped = NIOAny(outbound)
        if endsResponse {
            clientContext.writeAndFlush(wrapped, promise: nil)
        } else {
            clientContext.write(wrapped, promise: nil)
        }
    }

    /// The upstream connection is gone, so tear down the client side too.
    func channelInactive(context: ChannelHandlerContext) {
        clientContext.close(promise: nil)
    }

    /// Logs the error, then closes the upstream channel followed by the client channel.
    func errorCaught(context: ChannelHandlerContext, error: Error) {
        print("[MITMRelay] Error: \(error)")
        context.close(promise: nil)
        clientContext.close(promise: nil)
    }
}
|