299 lines
12 KiB
Swift
299 lines
12 KiB
Swift
import Foundation
|
|
import NIOCore
|
|
import NIOPosix
|
|
import NIOHTTP1
|
|
|
|
/// Handles incoming proxy requests:
/// - HTTP CONNECT → establishes a TCP tunnel (GlueHandler passthrough, or MITM via MITMHandler)
/// - Plain HTTP → connects upstream, forwards the request, captures request + response
///
/// All upstream connections are bootstrapped on the client channel's own event loop
/// (`ClientBootstrap(group: context.eventLoop)`), so the connect callbacks below run on
/// the same loop as `channelRead` and may touch handler state and `context` directly.
final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPServerRequestPart
    typealias OutboundOut = HTTPServerResponsePart

    private let trafficRepo: TrafficRepository

    // Parts of the plain-HTTP request currently being received. They are buffered until
    // `.end` arrives and are then handed to `handleHTTPRequest` as a snapshot.
    // NOTE(review): the body is buffered without any size limit.
    private var pendingHead: HTTPRequestHead?
    private var pendingBody: [ByteBuffer] = []
    private var pendingEnd: HTTPHeaders?
    private var receivedEnd = false

    /// - Parameter trafficRepo: Repository that captured traffic entries are written to.
    init(trafficRepo: TrafficRepository) {
        self.trafficRepo = trafficRepo
    }

    /// Dispatches decoded request parts: CONNECT heads start a tunnel immediately,
    /// every other request is buffered and forwarded once its `.end` part arrives.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        switch unwrapInboundIn(data) {
        case .head(let head):
            if head.method == .CONNECT {
                handleConnect(context: context, head: head)
            } else {
                pendingHead = head
            }
        case .body(let buffer):
            // Only buffer bodies that belong to a pending plain-HTTP request; stray body
            // parts (e.g. following a CONNECT head) would otherwise accumulate forever.
            if pendingHead != nil {
                pendingBody.append(buffer)
            }
        case .end(let trailers):
            if pendingHead != nil {
                pendingEnd = trailers
                receivedEnd = true
                handleHTTPRequest(context: context)
            }
        }
    }

    // MARK: - Shared helpers

    /// Splits an authority string (`host`, `host:port`, `[v6-literal]`, `[v6-literal]:port`)
    /// into host and port.
    ///
    /// - Parameters:
    ///   - authority: Raw CONNECT target or `Host` header value.
    ///   - defaultPort: Port used when none is given or the given one is not an integer.
    /// - Returns: The host (without IPv6 brackets) and port, or `nil` when no host is present.
    private static func splitAuthority(_ authority: String, defaultPort: Int) -> (host: String, port: Int)? {
        let trimmed = authority.trimmingCharacters(in: .whitespaces)
        guard !trimmed.isEmpty else { return nil }

        if trimmed.hasPrefix("[") {
            // Bracketed IPv6 literal: the colons inside the brackets are not port separators.
            guard let closing = trimmed.firstIndex(of: "]") else { return nil }
            let host = String(trimmed[trimmed.index(after: trimmed.startIndex)..<closing])
            guard !host.isEmpty else { return nil }
            let remainder = trimmed[trimmed.index(after: closing)...]
            if remainder.isEmpty {
                return (host, defaultPort)
            }
            guard remainder.hasPrefix(":") else { return nil }
            return (host, Int(remainder.dropFirst()) ?? defaultPort)
        }

        // Keep empty subsequences so that ":443" is seen as "no host" instead of host "443".
        let parts = trimmed.split(separator: ":", omittingEmptySubsequences: false)
        guard let first = parts.first, !first.isEmpty else { return nil }
        let port = parts.count > 1 ? Int(parts[1]) ?? defaultPort : defaultPort
        return (String(first), port)
    }

    /// Writes a body-less response to the proxy client.
    ///
    /// `Content-Length: 0` is set explicitly, mirroring the SwiftNIO connect-proxy example:
    /// without a length the response encoder is expected to pick its own framing for the
    /// empty message, which must not happen on a CONNECT 200 because everything after the
    /// header block already belongs to the tunnel.
    ///
    /// - Parameters:
    ///   - status: Status code to send.
    ///   - closeWhenDone: Close the client channel once the write has completed.
    private func sendEmptyResponse(
        context: ChannelHandlerContext,
        status: HTTPResponseStatus,
        closeWhenDone: Bool = false
    ) {
        let head = HTTPResponseHead(
            version: .http1_1,
            status: status,
            headers: HTTPHeaders([("Content-Length", "0")])
        )
        let flushed: EventLoopPromise<Void>? = closeWhenDone ? context.eventLoop.makePromise() : nil
        flushed?.futureResult.whenComplete { _ in
            context.close(promise: nil)
        }
        context.write(wrapOutboundOut(.head(head)), promise: nil)
        context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: flushed)
    }

    /// Removes the HTTP request decoder (if present) so later inbound bytes reach the next
    /// handler undecoded.
    ///
    /// NOTE(review): only the request decoder is removed here. If the pipeline also holds
    /// an `HTTPResponseEncoder`, confirm that GlueHandler/MITMHandler remove it themselves
    /// before writing raw buffers through the channel.
    private func removeRequestDecoder(from context: ChannelHandlerContext) {
        context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
            .whenSuccess { decoder in
                context.channel.pipeline.removeHandler(decoder, promise: nil)
            }
    }

    /// Drops all buffered request state so the next request starts from a clean slate.
    private func resetPendingRequest() {
        pendingHead = nil
        pendingBody.removeAll()
        pendingEnd = nil
        receivedEnd = false
    }

    // MARK: - CONNECT (HTTPS tunnel)

    /// Handles a CONNECT request by either handing the connection to the MITM handler or
    /// tunnelling raw bytes to the requested host.
    ///
    /// In passthrough mode the 200 is sent only after the upstream connection exists, and
    /// in the same event-loop turn as the pipeline swap, so the client cannot start sending
    /// tunnel bytes while the HTTP decoder is still installed. A failed connect is reported
    /// as 502 instead of a success followed by a dropped connection.
    private func handleConnect(context: ChannelHandlerContext, head: HTTPRequestHead) {
        guard let (host, port) = ConnectHandler.splitAuthority(head.uri, defaultPort: 443) else {
            sendEmptyResponse(context: context, status: .badRequest, closeWhenDone: true)
            return
        }

        // MITM only when SSL proxying is enabled, a CA exists and the domain is included.
        if shouldInterceptSSL(domain: host) {
            sendEmptyResponse(context: context, status: .ok)
            // MITM mode: strip HTTP handlers, install MITMHandler.
            setupMITM(context: context, host: host, port: port)
            return
        }

        // Passthrough mode: GlueHandler pairs two channels, so the remote channel must
        // exist before the client side can be switched over to raw forwarding.
        ClientBootstrap(group: context.eventLoop)
            .channelOption(.socketOption(.so_reuseaddr), value: 1)
            .connect(host: host, port: port)
            .whenComplete { result in
                switch result {
                case .success(let remoteChannel):
                    self.sendEmptyResponse(context: context, status: .ok)
                    self.setupGlue(context: context, remoteChannel: remoteChannel)
                    // Recorded only now, so the stored 200 reflects what the client was told.
                    self.recordConnectTraffic(host: host, port: port)
                case .failure(let error):
                    print("[Proxy] CONNECT passthrough failed to \(host):\(port): \(error)")
                    self.sendEmptyResponse(context: context, status: .badGateway, closeWhenDone: true)
                }
            }
    }

    /// Decides whether TLS for `domain` should be intercepted.
    ///
    /// Returns `false` when SSL proxying is disabled, no CA is available, the rule list
    /// cannot be loaded, or the domain matches an exclude rule; exclude rules win over
    /// include rules. Returns `true` only when an include rule matches.
    ///
    /// NOTE(review): this creates a `RulesRepository` and fetches the rules on every
    /// CONNECT, on the calling thread — consider caching if the fetch is expensive.
    private func shouldInterceptSSL(domain: String) -> Bool {
        guard IPCManager.shared.isSSLProxyingEnabled else { return false }
        guard CertificateManager.shared.hasCA else { return false }

        do {
            let entries = try RulesRepository().fetchAllSSLEntries()

            let isExcluded = entries.contains { entry in
                !entry.isInclude && WildcardMatcher.matches(domain, pattern: entry.domainPattern)
            }
            if isExcluded {
                return false
            }

            return entries.contains { entry in
                entry.isInclude && WildcardMatcher.matches(domain, pattern: entry.domainPattern)
            }
        } catch {
            print("[Proxy] Failed to check SSL proxying list: \(error)")
            return false
        }
    }

    /// Replaces this handler (and the HTTP request decoder) with a `MITMHandler` for
    /// `host:port`. The client channel is closed if the handler cannot be installed.
    private func setupMITM(context: ChannelHandlerContext, host: String, port: Int) {
        let mitmHandler = MITMHandler(host: host, port: port, trafficRepo: trafficRepo)

        removeRequestDecoder(from: context)

        context.pipeline.removeHandler(context: context).whenComplete { _ in
            context.channel.pipeline.addHandler(mitmHandler).whenFailure { error in
                print("[Proxy] Failed to install MITM handler: \(error)")
                context.close(promise: nil)
            }
        }
    }

    /// Wires the client channel and `remoteChannel` together with a pair of partnered
    /// `GlueHandler`s, replacing this handler and the HTTP request decoder.
    ///
    /// If either glue handler cannot be installed (for example because the client went
    /// away while the upstream connect was in flight) both channels are closed, so the
    /// freshly opened upstream connection is never leaked.
    private func setupGlue(context: ChannelHandlerContext, remoteChannel: Channel) {
        let localGlue = GlueHandler()
        let remoteGlue = GlueHandler()
        localGlue.partner = remoteGlue
        remoteGlue.partner = localGlue

        removeRequestDecoder(from: context)

        context.pipeline.removeHandler(context: context).whenComplete { _ in
            context.channel.pipeline.addHandler(localGlue)
                .flatMap { remoteChannel.pipeline.addHandler(remoteGlue) }
                .whenFailure { error in
                    print("[Proxy] Failed to install tunnel handlers: \(error)")
                    context.close(promise: nil)
                    remoteChannel.close(promise: nil)
                }
        }
    }

    // MARK: - Plain HTTP forwarding

    /// Forwards the fully buffered plain-HTTP request to its origin server.
    ///
    /// The buffered parts are snapshotted and the buffers reset before anything
    /// asynchronous happens, so neither an error response nor a later request on the same
    /// connection can leave stale state behind or change what gets forwarded. Replies with
    /// 400 when no target can be derived and 502 when the upstream connect fails.
    private func handleHTTPRequest(context: ChannelHandlerContext) {
        guard let head = pendingHead else { return }
        let bodyParts = pendingBody
        let trailers = pendingEnd
        resetPendingRequest()

        // Parse host and port from the absolute URI or the Host header.
        guard let (host, port, path) = parseHTTPTarget(head: head) else {
            sendEmptyResponse(context: context, status: .badRequest)
            return
        }

        // Upstream expects an origin-form target (/path), not http://host/path.
        var upstreamHead = head
        upstreamHead.uri = path
        if !upstreamHead.headers.contains(name: "Host") {
            // A non-default port has to be part of the Host header value.
            upstreamHead.headers.add(name: "Host", value: port == 80 ? host : "\(host):\(port)")
        }

        let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: host, scheme: "http")

        ClientBootstrap(group: context.eventLoop)
            .channelOption(.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                // Remote channel: encode HTTP requests, decode HTTP responses, capture, relay.
                channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
                    channel.pipeline.addHandler(
                        ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))
                    )
                }.flatMap {
                    channel.pipeline.addHandler(captureHandler)
                }.flatMap {
                    channel.pipeline.addHandler(
                        HTTPRelayHandler(clientContext: context, wrapResponse: self.wrapOutboundOut)
                    )
                }
            }
            .connect(host: host, port: port)
            .whenComplete { result in
                switch result {
                case .success(let remoteChannel):
                    // Replay the buffered request towards the origin server.
                    remoteChannel.write(NIOAny(HTTPClientRequestPart.head(upstreamHead)), promise: nil)
                    for bodyBuffer in bodyParts {
                        remoteChannel.write(
                            NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyBuffer))),
                            promise: nil
                        )
                    }
                    remoteChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
                case .failure(let error):
                    print("[Proxy] HTTP forward failed to \(host):\(port): \(error)")
                    self.sendEmptyResponse(context: context, status: .badGateway)
                }
            }
    }

    // MARK: - URL Parsing

    /// Derives the upstream host, port and origin-form request target from a request head.
    ///
    /// Absolute-form targets (`http://example.com:8080/path?query`) are parsed with
    /// `URLComponents`; the percent-encoded path and query are kept verbatim so escaped
    /// characters reach the origin server unchanged. Otherwise the `Host` header supplies
    /// host and port and the request URI is used as-is.
    ///
    /// - Returns: `nil` when no non-empty host can be determined.
    private func parseHTTPTarget(head: HTTPRequestHead) -> (host: String, port: Int, path: String)? {
        // URI schemes are case-insensitive.
        let loweredURI = head.uri.lowercased()
        let isHTTPS = loweredURI.hasPrefix("https://")

        if isHTTPS || loweredURI.hasPrefix("http://") {
            guard let url = URLComponents(string: head.uri),
                  let host = url.host, !host.isEmpty else {
                return nil
            }
            let port = url.port ?? (isHTTPS ? 443 : 80)
            var path = url.percentEncodedPath.isEmpty ? "/" : url.percentEncodedPath
            if let query = url.percentEncodedQuery {
                path += "?\(query)"
            }
            return (host, port, path)
        }

        // Origin-form target: the Host header names the server.
        guard let hostHeader = head.headers.first(name: "Host"),
              let (host, port) = ConnectHandler.splitAuthority(hostHeader, defaultPort: 80) else {
            return nil
        }
        return (host, port, head.uri)
    }

    // MARK: - CONNECT traffic recording

    /// Stores a domain-level entry for a passthrough (non-decrypted) CONNECT tunnel and
    /// notifies listeners. Insert failures are logged, and no notification is posted for
    /// an entry that was not stored.
    private func recordConnectTraffic(host: String, port: Int) {
        let now = Date().timeIntervalSince1970
        var traffic = CapturedTraffic(
            domain: host,
            url: "https://\(host):\(port)",
            method: "CONNECT",
            scheme: "https",
            statusCode: 200,
            statusText: "Connection Established",
            startedAt: now,
            completedAt: now,
            durationMs: 0,
            isSslDecrypted: false
        )

        do {
            try trafficRepo.insert(&traffic)
            IPCManager.shared.post(.newTrafficCaptured)
        } catch {
            print("[Proxy] Failed to record CONNECT traffic for \(host):\(port): \(error)")
        }
    }
}
|
|
|
|
// MARK: - HTTPRelayHandler
|
|
|
|
/// Relays HTTP responses from the upstream server back to the proxy client.
///
/// Installed on the upstream channel; every decoded response part is re-wrapped as a
/// server response part and written through the client's handler context.
final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = HTTPClientResponsePart

    /// Handler context on the client-facing channel that responses are written to.
    private let clientContext: ChannelHandlerContext
    /// Converts a server response part into the `NIOAny` expected by the client pipeline.
    private let wrapResponse: (HTTPServerResponsePart) -> NIOAny

    /// - Parameters:
    ///   - clientContext: Context of the client channel that receives the relayed response.
    ///   - wrapResponse: Wrapper used to box response parts for `clientContext`.
    init(clientContext: ChannelHandlerContext, wrapResponse: @escaping (HTTPServerResponsePart) -> NIOAny) {
        self.clientContext = clientContext
        self.wrapResponse = wrapResponse
    }

    /// Translates each upstream response part and passes it on to the client. Head and
    /// body parts are only written; the end part also flushes.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        switch unwrapInboundIn(data) {
        case let .head(upstreamHead):
            let relayedHead = HTTPResponseHead(
                version: upstreamHead.version,
                status: upstreamHead.status,
                headers: upstreamHead.headers
            )
            relay(.head(relayedHead), flush: false)
        case let .body(chunk):
            relay(.body(.byteBuffer(chunk)), flush: false)
        case let .end(trailers):
            relay(.end(trailers), flush: true)
        }
    }

    /// The upstream connection went away: close the client connection as well.
    func channelInactive(context: ChannelHandlerContext) {
        clientContext.close(promise: nil)
    }

    /// Logs the error and tears down both the upstream and the client connection.
    func errorCaught(context: ChannelHandlerContext, error: Error) {
        print("[Proxy] Relay error: \(error)")
        context.close(promise: nil)
        clientContext.close(promise: nil)
    }

    /// Writes one response part to the client, flushing when `flush` is set.
    private func relay(_ part: HTTPServerResponsePart, flush: Bool) {
        let payload = wrapResponse(part)
        if flush {
            clientContext.writeAndFlush(payload, promise: nil)
        } else {
            clientContext.write(payload, promise: nil)
        }
    }
}
|