import Foundation
import NIOCore
import NIOPosix
import NIOSSL
import NIOHTTP1

/// After a CONNECT tunnel is established, this handler:
/// 1. Reads the first bytes from the client to extract the SNI hostname from the TLS ClientHello
/// 2. Generates a per-domain leaf certificate via CertificateManager
/// 3. Terminates client-side TLS with the generated cert
/// 4. Initiates server-side TLS to the real server
/// 5. Installs HTTP codecs + HTTPCaptureHandler on both sides to capture decrypted traffic
final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = ByteBuffer

    private let host: String
    private let port: Int
    private let trafficRepo: TrafficRepository
    private let certManager: CertificateManager

    /// - Parameters:
    ///   - host: Host named in the CONNECT request; used as the fallback domain when no SNI is found.
    ///   - port: Port named in the CONNECT request.
    ///   - trafficRepo: Repository handed on to the forwarding handler.
    ///   - certManager: Source of the per-domain TLS server context.
    init(host: String, port: Int, trafficRepo: TrafficRepository, certManager: CertificateManager = .shared) {
        self.host = host
        self.port = port
        self.trafficRepo = trafficRepo
        self.certManager = certManager
    }

    /// Handles the first inbound bytes of the tunnel: picks the domain, replaces this handler with
    /// the TLS + HTTP pipeline, then replays the bytes so the TLS handshake can proceed.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let buffer = unwrapInboundIn(data)

        // Extract SNI from ClientHello if possible, otherwise use the CONNECT host
        let sniDomain = extractSNI(from: buffer) ?? host

        // Grab the channel BEFORE removing ourselves: once this handler has left the pipeline its
        // context is unlinked, and `context.close` would no longer reach the channel.
        let channel = context.channel
        let pipeline = channel.pipeline

        // Remove this handler — we'll rebuild the pipeline
        context.pipeline.removeHandler(self, promise: nil)

        // Get TLS context for this domain
        let sslContext: NIOSSLContext
        do {
            sslContext = try certManager.tlsServerContext(for: sniDomain)
        } catch {
            print("[MITM] Failed to get TLS context for \(sniDomain): \(error)")
            channel.close(promise: nil)
            return
        }

        // Add server-side TLS handler (we are the "server" to the client)
        let sslServerHandler = NIOSSLServerHandler(context: sslContext)
        let repo = trafficRepo
        let remoteHost = host
        let remotePort = port

        pipeline.addHandler(sslServerHandler, position: .first).flatMap {
            // Add HTTP codec after TLS
            pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))
        }.flatMap {
            pipeline.addHandler(HTTPResponseEncoder())
        }.flatMap {
            // Add the forwarding handler that connects to the real server
            pipeline.addHandler(
                MITMForwardHandler(
                    remoteHost: remoteHost,
                    remotePort: remotePort,
                    domain: sniDomain,
                    trafficRepo: repo
                )
            )
        }.whenComplete { result in
            switch result {
            case .success:
                // Re-fire the original ClientHello bytes so TLS handshake proceeds
                pipeline.fireChannelRead(NIOAny(buffer))
            case .failure(let error):
                print("[MITM] Pipeline setup failed: \(error)")
                channel.close(promise: nil)
            }
        }
    }

    // MARK: - SNI Extraction

    /// Parse the SNI hostname from a TLS ClientHello message.
    ///
    /// Works on a copy of `buffer`, so the caller's reader index is untouched.
    ///
    /// - Parameter buffer: The first bytes received from the client.
    /// - Returns: The first `host_name` entry of the server-name extension, or `nil` when the bytes
    ///   are not a (complete enough) ClientHello, carry no SNI extension, or the name is empty or
    ///   not valid UTF-8.
    private func extractSNI(from buffer: ByteBuffer) -> String? {
        var buf = buffer
        // 43 = record header (5) + handshake header (4) + client version (2) + random (32)
        guard buf.readableBytes >= 43 else { return nil }

        // TLS record header: content type must be Handshake, then version (2) + length (2)
        guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil }
        guard buf.readBytes(length: 4) != nil else { return nil }

        // Handshake header: message type must be ClientHello, then length (3)
        guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil }
        guard buf.readBytes(length: 3) != nil else { return nil }

        // Client version (2) + random (32)
        guard buf.readBytes(length: 2 + 32) != nil else { return nil }

        // Session ID
        guard let sessionIdLen = buf.readInteger(as: UInt8.self),
              buf.readBytes(length: Int(sessionIdLen)) != nil else { return nil }

        // Cipher suites
        guard let cipherSuitesLen = buf.readInteger(as: UInt16.self),
              buf.readBytes(length: Int(cipherSuitesLen)) != nil else { return nil }

        // Compression methods
        guard let compMethodsLen = buf.readInteger(as: UInt8.self),
              buf.readBytes(length: Int(compMethodsLen)) != nil else { return nil }

        // Extensions
        guard let extensionsLen = buf.readInteger(as: UInt16.self) else { return nil }
        var extensionsRemaining = Int(extensionsLen)

        // Each extension has a 4-byte header (type + length); `>=` so that a trailing
        // extension with an empty body is still visited.
        while extensionsRemaining >= 4 {
            guard let extType = buf.readInteger(as: UInt16.self),
                  let extLen = buf.readInteger(as: UInt16.self) else { return nil }
            extensionsRemaining -= 4 + Int(extLen)

            guard extType == 0x0000 else {
                // Not the SNI extension: skip its body
                guard buf.readBytes(length: Int(extLen)) != nil else { return nil }
                continue
            }

            // SNI extension: list length, then the first entry (type, length, name)
            guard buf.readInteger(as: UInt16.self) != nil,          // SNI list length
                  let nameType = buf.readInteger(as: UInt8.self),
                  nameType == 0x00,                                 // hostname
                  let nameLen = buf.readInteger(as: UInt16.self),
                  let nameBytes = buf.readBytes(length: Int(nameLen)),
                  let name = String(bytes: nameBytes, encoding: .utf8),
                  !name.isEmpty                                     // empty name: fall back to CONNECT host
            else { return nil }
            return name
        }
        return nil
    }
}

// MARK: - MITMForwardHandler

/// Handles decrypted HTTP from the client, forwards to the real server over TLS,
/// and relays responses
/// back. Captures everything via HTTPCaptureHandler.
final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPServerRequestPart
    typealias OutboundOut = HTTPServerResponsePart

    private let remoteHost: String
    private let remotePort: Int
    private let domain: String
    private let trafficRepo: TrafficRepository
    private var remoteChannel: Channel?

    // Buffer request parts until upstream is connected
    private var pendingParts: [HTTPServerRequestPart] = []
    private var isConnected = false

    /// - Parameters:
    ///   - remoteHost: Host to open the upstream connection to.
    ///   - remotePort: Port to open the upstream connection to.
    ///   - domain: Name used for upstream SNI, the fallback `Host` header and traffic capture.
    ///   - trafficRepo: Repository passed to the capture handler.
    init(remoteHost: String, remotePort: Int, domain: String, trafficRepo: TrafficRepository) {
        self.remoteHost = remoteHost
        self.remotePort = remotePort
        self.domain = domain
        self.trafficRepo = trafficRepo
    }

    /// Starts the upstream connection as soon as the handler joins the client pipeline.
    func handlerAdded(context: ChannelHandlerContext) {
        connectToRemote(context: context)
    }

    /// Forwards a decoded request part upstream, or buffers it while upstream is not connected yet.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let part = unwrapInboundIn(data)
        guard isConnected, let remote = remoteChannel else {
            pendingParts.append(part)
            return
        }
        forward(part, to: remote)
    }

    /// Tears down the upstream connection when the client side goes away.
    func channelInactive(context: ChannelHandlerContext) {
        remoteChannel?.close(promise: nil)
        context.fireChannelInactive()
    }

    func errorCaught(context: ChannelHandlerContext, error: Error) {
        print("[MITMForward] Error: \(error)")
        context.close(promise: nil)
        remoteChannel?.close(promise: nil)
    }

    /// Opens the TLS connection to the real server and installs the HTTP codec, the capture
    /// handler and the relay handler on it. On success, buffered request parts are flushed.
    private func connectToRemote(context: ChannelHandlerContext) {
        let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: domain, scheme: "https")
        let clientContext = context

        // NIOSSLClientHandler rejects IP literals as SNI names, so send no SNI when the
        // domain is an IP address (e.g. `CONNECT 10.0.0.1:443` without a ClientHello SNI).
        let domainIsIPAddress = (try? SocketAddress(ipAddress: domain, port: remotePort)) != nil
        let serverHostname: String? = domainIsIPAddress ? nil : domain

        do {
            let tlsConfig = TLSConfiguration.makeClientConfiguration()
            let sslContext = try NIOSSLContext(configuration: tlsConfig)

            ClientBootstrap(group: context.eventLoop)
                .channelOption(.socketOption(.so_reuseaddr), value: 1)
                .channelInitializer { channel in
                    // The domain is client-controlled; an invalid name must fail this
                    // connection attempt, not crash the process.
                    let tlsHandler: NIOSSLClientHandler
                    do {
                        tlsHandler = try NIOSSLClientHandler(context: sslContext, serverHostname: serverHostname)
                    } catch {
                        return channel.eventLoop.makeFailedFuture(error)
                    }
                    return channel.pipeline.addHandler(tlsHandler).flatMap {
                        channel.pipeline.addHandler(HTTPRequestEncoder())
                    }.flatMap {
                        channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))
                    }.flatMap {
                        channel.pipeline.addHandler(captureHandler)
                    }.flatMap {
                        channel.pipeline.addHandler(
                            MITMRelayHandler(clientContext: clientContext)
                        )
                    }
                }
                .connect(host: remoteHost, port: remotePort)
                .whenComplete { result in
                    switch result {
                    case .success(let channel):
                        // The client may have disconnected while we were connecting; in that
                        // case channelInactive already ran and nobody else would close upstream.
                        guard clientContext.channel.isActive else {
                            channel.close(promise: nil)
                            return
                        }
                        self.remoteChannel = channel
                        self.isConnected = true
                        self.flushPending(remote: channel)
                    case .failure(let error):
                        print("[MITMForward] Connect to \(self.remoteHost):\(self.remotePort) failed: \(error)")
                        self.pendingParts.removeAll()
                        clientContext.close(promise: nil)
                    }
                }
        } catch {
            print("[MITMForward] TLS setup failed: \(error)")
            context.close(promise: nil)
        }
    }

    /// Sends every buffered request part upstream, in arrival order, and empties the buffer.
    private func flushPending(remote: Channel) {
        let parts = pendingParts
        pendingParts.removeAll()
        for part in parts {
            forward(part, to: remote)
        }
    }

    /// Converts one server-side request part into the matching client request part and writes it
    /// to `remote`. Only `.end` flushes, so a whole request goes out in one flush.
    private func forward(_ part: HTTPServerRequestPart, to remote: Channel) {
        switch part {
        case .head(let head):
            var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
            if !clientHead.headers.contains(name: "Host") {
                clientHead.headers.add(name: "Host", value: domain)
            }
            remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
        case .body(let buffer):
            remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
        case .end(let trailers):
            remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
        }
    }
}

// MARK: - MITMRelayHandler

/// Relays responses from the real server back to the proxy client.
final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPClientResponsePart

    // Context of the handler on the client-facing channel; responses are written through it.
    private let clientContext: ChannelHandlerContext

    init(clientContext: ChannelHandlerContext) {
        self.clientContext = clientContext
    }

    /// Re-wraps each upstream response part as a server response part and writes it to the
    /// client side. Head and body parts are only written; the end part also flushes.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let downstream = clientContext

        switch unwrapInboundIn(data) {
        case let .head(upstreamHead):
            let relayedHead = HTTPResponseHead(
                version: upstreamHead.version,
                status: upstreamHead.status,
                headers: upstreamHead.headers
            )
            let outbound = HTTPServerResponsePart.head(relayedHead)
            downstream.write(NIOAny(outbound), promise: nil)

        case let .body(chunk):
            let outbound = HTTPServerResponsePart.body(.byteBuffer(chunk))
            downstream.write(NIOAny(outbound), promise: nil)

        case let .end(trailingHeaders):
            let outbound = HTTPServerResponsePart.end(trailingHeaders)
            downstream.writeAndFlush(NIOAny(outbound), promise: nil)
        }
    }

    /// Upstream went away: close the client side as well.
    func channelInactive(context: ChannelHandlerContext) {
        clientContext.close(promise: nil)
    }

    /// Logs the error, then closes the upstream side followed by the client side.
    func errorCaught(context: ChannelHandlerContext, error: Error) {
        print("[MITMRelay] Error: \(error)")
        context.close(promise: nil)
        clientContext.close(promise: nil)
    }
}