Add iPad support, auto-pinning, and comprehensive logging
- Adaptive iPhone/iPad layout with NavigationSplitView sidebar - Auto-detect SSL-pinned domains, fall back to passthrough - Certificate install via local HTTP server (Safari profile flow) - App Group-backed CA, per-domain leaf cert LRU cache - DB-backed config repository, Darwin notification throttling - Rules engine, breakpoint rules, pinned domain tracking - os.Logger instrumentation across tunnel/proxy/mitm/capture/cert/rules/db/ipc/ui - Fix dyld framework embed, race conditions, thread safety
This commit is contained in:
@@ -7,7 +7,8 @@ import NIOSSL
|
||||
import UIKit
|
||||
#endif
|
||||
|
||||
/// Manages root CA generation, leaf certificate signing, and an LRU certificate cache.
|
||||
/// Manages the shared MITM root CA. The app owns generation and writes the CA
|
||||
/// into the App Group container; the extension only loads that shared identity.
|
||||
public final class CertificateManager: @unchecked Sendable {
|
||||
public static let shared = CertificateManager()
|
||||
|
||||
@@ -15,16 +16,16 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
private var rootCAKey: P256.Signing.PrivateKey?
|
||||
private var rootCACert: Certificate?
|
||||
private var rootCANIOSSL: NIOSSLCertificate?
|
||||
private var caFingerprintCache: String?
|
||||
private var certificateMTime: Date?
|
||||
private var keyMTime: Date?
|
||||
|
||||
// LRU cache for generated leaf certificates
|
||||
private var certCache: [String: (NIOSSLCertificate, NIOSSLPrivateKey)] = [:]
|
||||
private var cacheOrder: [String] = []
|
||||
|
||||
private let keychainCAKeyTag = "com.treyt.proxyapp.ca.privatekey"
|
||||
private let keychainCACertTag = "com.treyt.proxyapp.ca.cert"
|
||||
|
||||
private init() {
|
||||
loadOrGenerateCA()
|
||||
loadOrGenerateCAIfNeeded()
|
||||
}
|
||||
|
||||
// MARK: - Public API
|
||||
@@ -32,7 +33,25 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
public var hasCA: Bool {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
return rootCACert != nil
|
||||
refreshFromDiskLocked()
|
||||
return rootCACert != nil && rootCAKey != nil
|
||||
}
|
||||
|
||||
public var caFingerprint: String? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
refreshFromDiskLocked()
|
||||
return caFingerprintCache
|
||||
}
|
||||
|
||||
public var canGenerateCA: Bool {
|
||||
Bundle.main.infoDictionary?["NSExtension"] == nil
|
||||
}
|
||||
|
||||
public func reloadSharedCA() {
|
||||
lock.lock()
|
||||
refreshFromDiskLocked(force: true)
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
/// Get or generate a leaf certificate + key for the given domain.
|
||||
@@ -40,41 +59,57 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
|
||||
refreshFromDiskLocked()
|
||||
|
||||
if let cached = certCache[domain] {
|
||||
cacheOrder.removeAll { $0 == domain }
|
||||
cacheOrder.append(domain)
|
||||
ProxyLogger.cert.debug("TLS context CACHE HIT for \(domain)")
|
||||
return try makeServerContext(cert: cached.0, key: cached.1)
|
||||
}
|
||||
|
||||
guard let caKey = rootCAKey, let caCert = rootCACert else {
|
||||
ProxyLogger.cert.error("TLS context FAILED for \(domain): no CA loaded. hasKey=\(self.rootCAKey != nil) hasCert=\(self.rootCACert != nil)")
|
||||
throw CertificateError.caNotFound
|
||||
}
|
||||
|
||||
let (leafCert, leafKey) = try generateLeaf(domain: domain, caKey: caKey, caCert: caCert)
|
||||
ProxyLogger.cert.info("TLS: generating leaf cert for \(domain), CA issuer=\(String(describing: caCert.subject))")
|
||||
|
||||
// Serialize to DER/PEM for NIOSSL
|
||||
var serializer = DER.Serializer()
|
||||
try leafCert.serialize(into: &serializer)
|
||||
let leafDER = serializer.serializedBytes
|
||||
let nioLeafCert = try NIOSSLCertificate(bytes: leafDER, format: .der)
|
||||
let leafKeyPEM = leafKey.pemRepresentation
|
||||
let nioLeafKey = try NIOSSLPrivateKey(bytes: [UInt8](leafKeyPEM.utf8), format: .pem)
|
||||
do {
|
||||
let (leafCert, leafKey) = try generateLeaf(domain: domain, caKey: caKey, caCert: caCert)
|
||||
ProxyLogger.cert.info("TLS: leaf cert generated for \(domain), SAN=\(domain), notBefore=\(leafCert.notValidBefore), notAfter=\(leafCert.notValidAfter)")
|
||||
|
||||
certCache[domain] = (nioLeafCert, nioLeafKey)
|
||||
cacheOrder.append(domain)
|
||||
var serializer = DER.Serializer()
|
||||
try leafCert.serialize(into: &serializer)
|
||||
let leafDER = serializer.serializedBytes
|
||||
ProxyLogger.cert.debug("TLS: leaf DER size=\(leafDER.count) bytes")
|
||||
|
||||
while cacheOrder.count > ProxyConstants.certificateCacheSize {
|
||||
let evicted = cacheOrder.removeFirst()
|
||||
certCache.removeValue(forKey: evicted)
|
||||
let nioLeafCert = try NIOSSLCertificate(bytes: leafDER, format: .der)
|
||||
let leafKeyPEM = leafKey.pemRepresentation
|
||||
let nioLeafKey = try NIOSSLPrivateKey(bytes: [UInt8](leafKeyPEM.utf8), format: .pem)
|
||||
|
||||
certCache[domain] = (nioLeafCert, nioLeafKey)
|
||||
cacheOrder.append(domain)
|
||||
|
||||
while cacheOrder.count > ProxyConstants.certificateCacheSize {
|
||||
let evicted = cacheOrder.removeFirst()
|
||||
certCache.removeValue(forKey: evicted)
|
||||
}
|
||||
|
||||
let ctx = try makeServerContext(cert: nioLeafCert, key: nioLeafKey)
|
||||
ProxyLogger.cert.info("TLS: server context READY for \(domain)")
|
||||
return ctx
|
||||
} catch {
|
||||
ProxyLogger.cert.error("TLS: leaf cert/context FAILED for \(domain): \(error)")
|
||||
throw error
|
||||
}
|
||||
|
||||
return try makeServerContext(cert: nioLeafCert, key: nioLeafKey)
|
||||
}
|
||||
|
||||
/// Export the root CA as DER data for user installation.
|
||||
public func exportCACertificateDER() -> [UInt8]? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
refreshFromDiskLocked()
|
||||
guard let cert = rootCACert else { return nil }
|
||||
var serializer = DER.Serializer()
|
||||
try? cert.serialize(into: &serializer)
|
||||
@@ -85,63 +120,173 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
public func exportCACertificatePEM() -> String? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
refreshFromDiskLocked()
|
||||
guard let cert = rootCACert else { return nil }
|
||||
guard let pem = try? cert.serializeAsPEM() else { return nil }
|
||||
return pem.pemString
|
||||
}
|
||||
|
||||
public var caNotValidAfter: Date? {
|
||||
public var caGeneratedDate: Date? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
// notValidAfter is a Time, not directly a Date — we stored the date when generating
|
||||
return nil // Will be set properly after we store dates
|
||||
refreshFromDiskLocked()
|
||||
return rootCACert?.notValidBefore
|
||||
}
|
||||
|
||||
// MARK: - CA Generation
|
||||
public var caExpirationDate: Date? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
refreshFromDiskLocked()
|
||||
return rootCACert?.notValidAfter
|
||||
}
|
||||
|
||||
private func loadOrGenerateCA() {
|
||||
if loadCAFromKeychain() { return }
|
||||
public func regenerateCA() {
|
||||
guard canGenerateCA else {
|
||||
ProxyLogger.cert.error("Refusing to regenerate CA from extension context")
|
||||
return
|
||||
}
|
||||
|
||||
lock.lock()
|
||||
clearStateLocked()
|
||||
deleteStoredCALocked()
|
||||
|
||||
do {
|
||||
let key = P256.Signing.PrivateKey()
|
||||
let name = try DistinguishedName {
|
||||
CommonName("Proxy CA (\(deviceName()))")
|
||||
OrganizationName("ProxyApp")
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let twoYearsLater = now.addingTimeInterval(365 * 24 * 3600 * 2)
|
||||
|
||||
let extensions = try Certificate.Extensions {
|
||||
Critical(BasicConstraints.isCertificateAuthority(maxPathLength: 0))
|
||||
Critical(KeyUsage(keyCertSign: true, cRLSign: true))
|
||||
}
|
||||
|
||||
let cert = try Certificate(
|
||||
version: .v3,
|
||||
serialNumber: Certificate.SerialNumber(),
|
||||
publicKey: .init(key.publicKey),
|
||||
notValidBefore: now,
|
||||
notValidAfter: twoYearsLater,
|
||||
issuer: name,
|
||||
subject: name,
|
||||
signatureAlgorithm: .ecdsaWithSHA256,
|
||||
extensions: extensions,
|
||||
issuerPrivateKey: .init(key)
|
||||
)
|
||||
|
||||
self.rootCAKey = key
|
||||
self.rootCACert = cert
|
||||
|
||||
var serializer = DER.Serializer()
|
||||
try cert.serialize(into: &serializer)
|
||||
let der = serializer.serializedBytes
|
||||
self.rootCANIOSSL = try NIOSSLCertificate(bytes: der, format: .der)
|
||||
|
||||
saveCAToKeychain(key: key, certDER: der)
|
||||
print("[CertificateManager] Generated new root CA")
|
||||
try generateAndStoreCALocked()
|
||||
} catch {
|
||||
print("[CertificateManager] Failed to generate CA: \(error)")
|
||||
ProxyLogger.cert.error("CA regeneration failed: \(error.localizedDescription)")
|
||||
}
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
// MARK: - CA bootstrap
|
||||
|
||||
private func loadOrGenerateCAIfNeeded() {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
|
||||
refreshFromDiskLocked(force: true)
|
||||
guard rootCACert == nil || rootCAKey == nil else { return }
|
||||
|
||||
guard canGenerateCA else {
|
||||
ProxyLogger.cert.info("Shared CA not found; extension will remain passthrough-only")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try generateAndStoreCALocked()
|
||||
} catch {
|
||||
ProxyLogger.cert.error("Failed to generate shared CA: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
private func refreshFromDiskLocked(force: Bool = false) {
|
||||
let certURL = AppGroupPaths.caCertificateURL
|
||||
let keyURL = AppGroupPaths.caPrivateKeyURL
|
||||
|
||||
let certExists = FileManager.default.fileExists(atPath: certURL.path)
|
||||
let keyExists = FileManager.default.fileExists(atPath: keyURL.path)
|
||||
guard certExists, keyExists else {
|
||||
if rootCACert != nil || rootCAKey != nil {
|
||||
ProxyLogger.cert.info("Shared CA files missing; clearing in-memory state")
|
||||
}
|
||||
clearStateLocked()
|
||||
return
|
||||
}
|
||||
|
||||
let currentCertMTime = modificationDate(for: certURL)
|
||||
let currentKeyMTime = modificationDate(for: keyURL)
|
||||
if !force, currentCertMTime == certificateMTime, currentKeyMTime == keyMTime {
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
let certData = try Data(contentsOf: certURL)
|
||||
let keyData = try Data(contentsOf: keyURL)
|
||||
let key = try P256.Signing.PrivateKey(rawRepresentation: keyData)
|
||||
let cert = try Certificate(derEncoded: [UInt8](certData))
|
||||
let nioCert = try NIOSSLCertificate(bytes: [UInt8](certData), format: .der)
|
||||
|
||||
rootCAKey = key
|
||||
rootCACert = cert
|
||||
rootCANIOSSL = nioCert
|
||||
certificateMTime = currentCertMTime
|
||||
keyMTime = currentKeyMTime
|
||||
caFingerprintCache = fingerprint(for: certData)
|
||||
certCache.removeAll()
|
||||
cacheOrder.removeAll()
|
||||
ProxyLogger.cert.info("Loaded shared CA from App Group container")
|
||||
} catch {
|
||||
ProxyLogger.cert.error("Failed to load shared CA: \(error.localizedDescription)")
|
||||
clearStateLocked()
|
||||
}
|
||||
}
|
||||
|
||||
private func generateAndStoreCALocked() throws {
|
||||
let key = P256.Signing.PrivateKey()
|
||||
let name = try DistinguishedName {
|
||||
CommonName("Proxy CA (\(deviceName()))")
|
||||
OrganizationName("ProxyApp")
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let twoYearsLater = now.addingTimeInterval(365 * 24 * 3600 * 2)
|
||||
|
||||
let extensions = try Certificate.Extensions {
|
||||
Critical(BasicConstraints.isCertificateAuthority(maxPathLength: 0))
|
||||
Critical(KeyUsage(keyCertSign: true, cRLSign: true))
|
||||
}
|
||||
|
||||
let cert = try Certificate(
|
||||
version: .v3,
|
||||
serialNumber: Certificate.SerialNumber(),
|
||||
publicKey: .init(key.publicKey),
|
||||
notValidBefore: now,
|
||||
notValidAfter: twoYearsLater,
|
||||
issuer: name,
|
||||
subject: name,
|
||||
signatureAlgorithm: .ecdsaWithSHA256,
|
||||
extensions: extensions,
|
||||
issuerPrivateKey: .init(key)
|
||||
)
|
||||
|
||||
var serializer = DER.Serializer()
|
||||
try cert.serialize(into: &serializer)
|
||||
let der = serializer.serializedBytes
|
||||
|
||||
try FileManager.default.createDirectory(
|
||||
at: AppGroupPaths.certificatesDirectory,
|
||||
withIntermediateDirectories: true,
|
||||
attributes: nil
|
||||
)
|
||||
try Data(der).write(to: AppGroupPaths.caCertificateURL, options: .atomic)
|
||||
try key.rawRepresentation.write(to: AppGroupPaths.caPrivateKeyURL, options: .atomic)
|
||||
|
||||
rootCAKey = key
|
||||
rootCACert = cert
|
||||
rootCANIOSSL = try NIOSSLCertificate(bytes: der, format: .der)
|
||||
certificateMTime = modificationDate(for: AppGroupPaths.caCertificateURL)
|
||||
keyMTime = modificationDate(for: AppGroupPaths.caPrivateKeyURL)
|
||||
caFingerprintCache = fingerprint(for: Data(der))
|
||||
certCache.removeAll()
|
||||
cacheOrder.removeAll()
|
||||
|
||||
ProxyLogger.cert.info("Generated new shared root CA")
|
||||
}
|
||||
|
||||
private func clearStateLocked() {
|
||||
rootCAKey = nil
|
||||
rootCACert = nil
|
||||
rootCANIOSSL = nil
|
||||
caFingerprintCache = nil
|
||||
certificateMTime = nil
|
||||
keyMTime = nil
|
||||
certCache.removeAll()
|
||||
cacheOrder.removeAll()
|
||||
}
|
||||
|
||||
private func deleteStoredCALocked() {
|
||||
for url in [AppGroupPaths.caCertificateURL, AppGroupPaths.caPrivateKeyURL] {
|
||||
try? FileManager.default.removeItem(at: url)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,81 +344,16 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
return try NIOSSLContext(configuration: config)
|
||||
}
|
||||
|
||||
// MARK: - Keychain
|
||||
|
||||
private func loadCAFromKeychain() -> Bool {
|
||||
let keyQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCAKeyTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecReturnData as String: true
|
||||
]
|
||||
var keyResult: AnyObject?
|
||||
guard SecItemCopyMatching(keyQuery as CFDictionary, &keyResult) == errSecSuccess,
|
||||
let keyData = keyResult as? Data else { return false }
|
||||
|
||||
let certQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCACertTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecReturnData as String: true
|
||||
]
|
||||
var certResult: AnyObject?
|
||||
guard SecItemCopyMatching(certQuery as CFDictionary, &certResult) == errSecSuccess,
|
||||
let certData = certResult as? Data else { return false }
|
||||
|
||||
do {
|
||||
let key = try P256.Signing.PrivateKey(rawRepresentation: keyData)
|
||||
let cert = try Certificate(derEncoded: [UInt8](certData))
|
||||
let nioCert = try NIOSSLCertificate(bytes: [UInt8](certData), format: .der)
|
||||
|
||||
self.rootCAKey = key
|
||||
self.rootCACert = cert
|
||||
self.rootCANIOSSL = nioCert
|
||||
print("[CertificateManager] Loaded CA from Keychain")
|
||||
return true
|
||||
} catch {
|
||||
print("[CertificateManager] Failed to load CA from Keychain: \(error)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func saveCAToKeychain(key: P256.Signing.PrivateKey, certDER: [UInt8]) {
|
||||
let keyData = key.rawRepresentation
|
||||
|
||||
// Delete existing entries
|
||||
for tag in [keychainCAKeyTag, keychainCACertTag] {
|
||||
let deleteQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: tag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier
|
||||
]
|
||||
SecItemDelete(deleteQuery as CFDictionary)
|
||||
}
|
||||
|
||||
// Save key
|
||||
let addKeyQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCAKeyTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecValueData as String: keyData,
|
||||
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
|
||||
]
|
||||
SecItemAdd(addKeyQuery as CFDictionary, nil)
|
||||
|
||||
// Save cert
|
||||
let addCertQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCACertTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecValueData as String: Data(certDER),
|
||||
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
|
||||
]
|
||||
SecItemAdd(addCertQuery as CFDictionary, nil)
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func modificationDate(for url: URL) -> Date? {
|
||||
(try? FileManager.default.attributesOfItem(atPath: url.path)[.modificationDate] as? Date) ?? nil
|
||||
}
|
||||
|
||||
private func fingerprint(for data: Data) -> String {
|
||||
SHA256.hash(data: data).map { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func deviceName() -> String {
|
||||
#if canImport(UIKit)
|
||||
return UIDevice.current.name
|
||||
@@ -283,8 +363,6 @@ public final class CertificateManager: @unchecked Sendable {
|
||||
}
|
||||
|
||||
public enum CertificateError: Error {
|
||||
case notImplemented
|
||||
case generationFailed
|
||||
case caNotFound
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,19 +4,21 @@ import NIOPosix
|
||||
import NIOHTTP1
|
||||
|
||||
/// Handles incoming proxy requests:
|
||||
/// - HTTP CONNECT → establishes TCP tunnel (GlueHandler passthrough, or MITM in Phase 3)
|
||||
/// - Plain HTTP → connects upstream, forwards request, captures request+response
|
||||
/// - HTTP CONNECT -> TCP tunnel (GlueHandler passthrough or MITM)
|
||||
/// - Plain HTTP -> forward with capture
|
||||
final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPServerRequestPart
|
||||
typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private let trafficRepo: TrafficRepository
|
||||
private let runtimeStatusRepo = RuntimeStatusRepository()
|
||||
|
||||
private var pendingConnectHead: HTTPRequestHead?
|
||||
private var pendingConnectBytes: [ByteBuffer] = []
|
||||
|
||||
// Buffer request parts until we've connected upstream
|
||||
private var pendingHead: HTTPRequestHead?
|
||||
private var pendingBody: [ByteBuffer] = []
|
||||
private var pendingEnd: HTTPHeaders?
|
||||
private var receivedEnd = false
|
||||
|
||||
init(trafficRepo: TrafficRepository) {
|
||||
self.trafficRepo = trafficRepo
|
||||
@@ -28,145 +30,400 @@ final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
switch part {
|
||||
case .head(let head):
|
||||
if head.method == .CONNECT {
|
||||
handleConnect(context: context, head: head)
|
||||
ProxyLogger.connect.info("CONNECT \(head.uri)")
|
||||
pendingConnectHead = head
|
||||
pendingConnectBytes.removeAll()
|
||||
} else {
|
||||
ProxyLogger.connect.info("HTTP \(head.method.rawValue) \(head.uri)")
|
||||
pendingHead = head
|
||||
pendingBody.removeAll()
|
||||
pendingEnd = nil
|
||||
}
|
||||
|
||||
case .body(let buffer):
|
||||
pendingBody.append(buffer)
|
||||
if pendingConnectHead != nil {
|
||||
pendingConnectBytes.append(buffer)
|
||||
} else {
|
||||
pendingBody.append(buffer)
|
||||
}
|
||||
|
||||
case .end(let trailers):
|
||||
if let connectHead = pendingConnectHead {
|
||||
let bufferedBytes = pendingConnectBytes
|
||||
pendingConnectHead = nil
|
||||
pendingConnectBytes.removeAll()
|
||||
handleConnect(context: context, head: connectHead, initialBuffers: bufferedBytes)
|
||||
return
|
||||
}
|
||||
|
||||
if pendingHead != nil {
|
||||
pendingEnd = trailers
|
||||
receivedEnd = true
|
||||
handleHTTPRequest(context: context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - CONNECT (HTTPS tunnel)
|
||||
// MARK: - CONNECT
|
||||
|
||||
private func handleConnect(context: ChannelHandlerContext, head: HTTPRequestHead) {
|
||||
private func handleConnect(
|
||||
context: ChannelHandlerContext,
|
||||
head: HTTPRequestHead,
|
||||
initialBuffers: [ByteBuffer]
|
||||
) {
|
||||
let components = head.uri.split(separator: ":")
|
||||
let host = String(components[0])
|
||||
let originalHost = String(components[0])
|
||||
let port = components.count > 1 ? Int(components[1]) ?? 443 : 443
|
||||
let connectURL = "https://\(originalHost):\(port)"
|
||||
|
||||
// Check if this domain should be MITM'd (SSL Proxying enabled + domain in include list)
|
||||
let shouldMITM = shouldInterceptSSL(domain: host)
|
||||
|
||||
// Send 200 Connection Established
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
|
||||
if shouldMITM {
|
||||
// MITM mode: strip HTTP handlers, install MITMHandler
|
||||
setupMITM(context: context, host: host, port: port)
|
||||
} else {
|
||||
// Passthrough mode: record domain-level entry, tunnel raw bytes
|
||||
recordConnectTraffic(host: host, port: port)
|
||||
|
||||
// We don't need to connect upstream ourselves — GlueHandler does raw forwarding
|
||||
// But GlueHandler pairs two channels, so we need the remote channel first
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.connect(host: host, port: port)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let remoteChannel):
|
||||
self.setupGlue(context: context, remoteChannel: remoteChannel)
|
||||
case .failure(let error):
|
||||
print("[Proxy] CONNECT passthrough failed to \(host):\(port): \(error)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func shouldInterceptSSL(domain: String) -> Bool {
|
||||
guard IPCManager.shared.isSSLProxyingEnabled else { return false }
|
||||
guard CertificateManager.shared.hasCA else { return false }
|
||||
|
||||
// Check SSL proxying list from database
|
||||
let rulesRepo = RulesRepository()
|
||||
do {
|
||||
let entries = try rulesRepo.fetchAllSSLEntries()
|
||||
|
||||
// Check exclude list first
|
||||
for entry in entries where !entry.isInclude {
|
||||
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
|
||||
return false
|
||||
if let blockAction = RulesEngine.checkBlockList(url: connectURL, method: "CONNECT"),
|
||||
blockAction != .hideOnly {
|
||||
ProxyLogger.connect.info("BLOCKED \(originalHost) action=\(blockAction.rawValue)")
|
||||
if blockAction == .blockAndDisplay {
|
||||
var traffic = CapturedTraffic(
|
||||
domain: originalHost, url: connectURL, method: "CONNECT", scheme: "https",
|
||||
statusCode: 403, statusText: "Blocked",
|
||||
startedAt: Date().timeIntervalSince1970,
|
||||
completedAt: Date().timeIntervalSince1970, durationMs: 0, isSslDecrypted: false
|
||||
)
|
||||
do {
|
||||
try trafficRepo.insert(&traffic)
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
} catch {
|
||||
ProxyLogger.db.error("DB insert blocked traffic failed: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
// Check include list
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .forbidden)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
context.close(promise: nil)
|
||||
return
|
||||
}
|
||||
|
||||
let upstreamHost = RulesEngine.checkDNSSpoof(domain: originalHost) ?? originalHost
|
||||
let shouldMITM = shouldInterceptSSL(domain: originalHost)
|
||||
let shouldHide = shouldHideConnect(url: connectURL, host: originalHost)
|
||||
ProxyLogger.connect.info("=== CONNECT original=\(originalHost) upstream=\(upstreamHost):\(port) mitm=\(shouldMITM) ===")
|
||||
|
||||
if shouldMITM {
|
||||
upgradeToMITM(
|
||||
context: context,
|
||||
originalHost: originalHost,
|
||||
upstreamHost: upstreamHost,
|
||||
port: port,
|
||||
initialBuffers: initialBuffers
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelOption(.autoRead, value: false)
|
||||
.connect(host: upstreamHost, port: port)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let remoteChannel):
|
||||
ProxyLogger.connect.info("Upstream connected to \(upstreamHost):\(port), upgrading to passthrough")
|
||||
self.upgradeToPassthrough(
|
||||
context: context,
|
||||
remoteChannel: remoteChannel,
|
||||
originalHost: originalHost,
|
||||
upstreamHost: upstreamHost,
|
||||
port: port,
|
||||
initialBuffers: initialBuffers,
|
||||
isHidden: shouldHide
|
||||
)
|
||||
|
||||
case .failure(let error):
|
||||
ProxyLogger.connect.error("Upstream connect FAILED \(upstreamHost):\(port): \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastConnectError = "CONNECT \(originalHost): \(error.localizedDescription)"
|
||||
}
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .badGateway)
|
||||
context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func shouldInterceptSSL(domain: String) -> Bool {
|
||||
let sslEnabled = IPCManager.shared.isSSLProxyingEnabled
|
||||
let hasCA = CertificateManager.shared.hasCA
|
||||
ProxyLogger.connect.info("shouldInterceptSSL(\(domain)): sslEnabled=\(sslEnabled) hasCA=\(hasCA)")
|
||||
|
||||
// Write diagnostic info so the app can display what the extension sees
|
||||
runtimeStatusRepo.update {
|
||||
$0.caFingerprint = CertificateManager.shared.caFingerprint
|
||||
$0.lastConnectError = "SSL check: domain=\(domain) sslEnabled=\(sslEnabled) hasCA=\(hasCA)"
|
||||
}
|
||||
|
||||
guard sslEnabled else {
|
||||
ProxyLogger.connect.info("SSL proxying DISABLED globally — skipping MITM")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "SSL proxying disabled (sslEnabled=false in DB)"
|
||||
}
|
||||
return false
|
||||
}
|
||||
guard hasCA else {
|
||||
ProxyLogger.connect.info("Shared CA unavailable in extension — skipping MITM")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "No CA in extension (hasCA=false)"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if domain was auto-detected as using SSL pinning
|
||||
if PinnedDomainRepository().isPinned(domain: domain) {
|
||||
ProxyLogger.connect.info("SSL PINNED (auto-detected): \(domain) — using passthrough")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "Pinned domain (auto-fallback): \(domain)"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
let rulesRepo = RulesRepository()
|
||||
do {
|
||||
let entries = try rulesRepo.fetchAllSSLEntries()
|
||||
let includeCount = entries.filter(\.isInclude).count
|
||||
let excludeCount = entries.filter { !$0.isInclude }.count
|
||||
let patterns = entries.map { "\($0.isInclude ? "+" : "-")\($0.domainPattern)" }.joined(separator: ", ")
|
||||
ProxyLogger.connect.info("SSL entries: \(entries.count) (include=\(includeCount) exclude=\(excludeCount)) patterns=[\(patterns)]")
|
||||
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastConnectError = "SSL rules: \(entries.count) entries [\(patterns)] checking domain=\(domain)"
|
||||
}
|
||||
|
||||
for entry in entries where !entry.isInclude {
|
||||
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
|
||||
ProxyLogger.connect.debug("SSL EXCLUDED by pattern: \(entry.domainPattern)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
for entry in entries where entry.isInclude {
|
||||
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
|
||||
ProxyLogger.connect.info("SSL INCLUDED by pattern: \(entry.domainPattern) -> MITM ON")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = nil
|
||||
$0.lastConnectError = "MITM enabled for \(domain) via pattern \(entry.domainPattern)"
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[Proxy] Failed to check SSL proxying list: \(error)")
|
||||
ProxyLogger.connect.error("SSL list fetch failed: \(error.localizedDescription)")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "SSL list DB error: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
|
||||
ProxyLogger.connect.debug("SSL: no matching rule for \(domain)")
|
||||
return false
|
||||
}
|
||||
|
||||
private func setupMITM(context: ChannelHandlerContext, host: String, port: Int) {
|
||||
let mitmHandler = MITMHandler(host: host, port: port, trafficRepo: trafficRepo)
|
||||
private func shouldHideConnect(url: String, host: String) -> Bool {
|
||||
if let blockAction = RulesEngine.checkBlockList(url: url, method: "CONNECT"), blockAction == .hideOnly {
|
||||
return true
|
||||
}
|
||||
return IPCManager.shared.hideSystemTraffic && SystemTrafficFilter.isSystemDomain(host)
|
||||
}
|
||||
|
||||
// Remove HTTP handlers, keep raw bytes for MITMHandler
|
||||
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
|
||||
.whenSuccess { handler in
|
||||
context.channel.pipeline.removeHandler(handler, promise: nil)
|
||||
}
|
||||
private func upgradeToMITM(
|
||||
context: ChannelHandlerContext,
|
||||
originalHost: String,
|
||||
upstreamHost: String,
|
||||
port: Int,
|
||||
initialBuffers: [ByteBuffer]
|
||||
) {
|
||||
let channel = context.channel
|
||||
|
||||
context.pipeline.removeHandler(context: context).whenComplete { _ in
|
||||
context.channel.pipeline.addHandler(mitmHandler).whenFailure { error in
|
||||
print("[Proxy] Failed to install MITM handler: \(error)")
|
||||
context.close(promise: nil)
|
||||
channel.setOption(.autoRead, value: false).flatMap {
|
||||
self.upgradeClientChannelToRaw(channel)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(
|
||||
MITMHandler(
|
||||
originalHost: originalHost,
|
||||
upstreamHost: upstreamHost,
|
||||
port: port,
|
||||
trafficRepo: self.trafficRepo
|
||||
)
|
||||
)
|
||||
}.flatMap {
|
||||
self.sendConnectEstablished(on: channel)
|
||||
}.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
ProxyLogger.connect.info("MITM pipeline ready for \(originalHost):\(port)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastConnectError = nil
|
||||
$0.lastMITMError = nil
|
||||
}
|
||||
channel.setOption(.autoRead, value: true).whenComplete { _ in
|
||||
channel.read()
|
||||
for buffer in initialBuffers {
|
||||
channel.pipeline.fireChannelRead(NIOAny(buffer))
|
||||
}
|
||||
}
|
||||
|
||||
case .failure(let error):
|
||||
ProxyLogger.connect.error("MITM upgrade FAILED for \(originalHost):\(port): \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "MITM setup \(originalHost): \(error.localizedDescription)"
|
||||
}
|
||||
channel.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func setupGlue(context: ChannelHandlerContext, remoteChannel: Channel) {
|
||||
private func upgradeToPassthrough(
|
||||
context: ChannelHandlerContext,
|
||||
remoteChannel: Channel,
|
||||
originalHost: String,
|
||||
upstreamHost: String,
|
||||
port: Int,
|
||||
initialBuffers: [ByteBuffer],
|
||||
isHidden: Bool
|
||||
) {
|
||||
let channel = context.channel
|
||||
let localGlue = GlueHandler()
|
||||
let remoteGlue = GlueHandler()
|
||||
localGlue.partner = remoteGlue
|
||||
remoteGlue.partner = localGlue
|
||||
|
||||
// Remove all HTTP handlers from the client channel, leaving raw bytes
|
||||
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
|
||||
.whenSuccess { handler in
|
||||
context.channel.pipeline.removeHandler(handler, promise: nil)
|
||||
}
|
||||
|
||||
context.pipeline.removeHandler(context: context).whenComplete { _ in
|
||||
context.channel.pipeline.addHandler(localGlue).whenSuccess {
|
||||
remoteChannel.pipeline.addHandler(remoteGlue).whenFailure { _ in
|
||||
context.close(promise: nil)
|
||||
remoteChannel.close(promise: nil)
|
||||
channel.setOption(.autoRead, value: false).flatMap {
|
||||
self.upgradeClientChannelToRaw(channel)
|
||||
}.flatMap {
|
||||
remoteChannel.pipeline.addHandler(remoteGlue)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(localGlue)
|
||||
}.flatMap {
|
||||
self.sendConnectEstablished(on: channel)
|
||||
}.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
ProxyLogger.connect.info("Passthrough tunnel ready for \(originalHost):\(port) via \(upstreamHost)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastConnectError = nil
|
||||
}
|
||||
self.recordConnectTraffic(host: originalHost, port: port, isHidden: isHidden)
|
||||
|
||||
for buffer in initialBuffers {
|
||||
remoteChannel.write(NIOAny(buffer), promise: nil)
|
||||
}
|
||||
remoteChannel.flush()
|
||||
|
||||
channel.setOption(.autoRead, value: true).whenComplete { _ in
|
||||
channel.read()
|
||||
}
|
||||
remoteChannel.setOption(.autoRead, value: true).whenComplete { _ in
|
||||
remoteChannel.read()
|
||||
}
|
||||
|
||||
case .failure(let error):
|
||||
ProxyLogger.connect.error("Passthrough upgrade FAILED for \(originalHost):\(port): \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastConnectError = "Passthrough \(originalHost): \(error.localizedDescription)"
|
||||
}
|
||||
channel.close(promise: nil)
|
||||
remoteChannel.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Plain HTTP forwarding
|
||||
private func upgradeClientChannelToRaw(_ channel: Channel) -> EventLoopFuture<Void> {
|
||||
removeHandler(ByteToMessageHandler<HTTPRequestDecoder>.self, from: channel).flatMap { _ in
|
||||
self.removeHandler(HTTPResponseEncoder.self, from: channel)
|
||||
}.flatMap { _ in
|
||||
channel.pipeline.removeHandler(self)
|
||||
}
|
||||
}
|
||||
|
||||
private func sendConnectEstablished(on channel: Channel) -> EventLoopFuture<Void> {
|
||||
var buffer = channel.allocator.buffer(capacity: 64)
|
||||
buffer.writeString("HTTP/1.1 200 Connection Established\r\n\r\n")
|
||||
return channel.writeAndFlush(NIOAny(buffer))
|
||||
}
|
||||
|
||||
private func removeHandler<H: RemovableChannelHandler>(_ type: H.Type, from channel: Channel) -> EventLoopFuture<Void> {
|
||||
channel.pipeline.handler(type: type).flatMap { handler in
|
||||
channel.pipeline.removeHandler(handler)
|
||||
}.recover { _ in () }
|
||||
}
|
||||
|
||||
// MARK: - Plain HTTP
|
||||
|
||||
private func handleHTTPRequest(context: ChannelHandlerContext) {
|
||||
guard let head = pendingHead else { return }
|
||||
|
||||
// Parse host and port from the absolute URI or Host header
|
||||
guard let (host, port, path) = parseHTTPTarget(head: head) else {
|
||||
ProxyLogger.connect.error("HTTP: failed to parse target from \(head.uri)")
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .badRequest)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
pendingHead = nil
|
||||
pendingBody.removeAll()
|
||||
pendingEnd = nil
|
||||
return
|
||||
}
|
||||
|
||||
let fullURL = "http://\(host)\(path)"
|
||||
let method = head.method.rawValue
|
||||
let upstreamHost = RulesEngine.checkDNSSpoof(domain: host) ?? host
|
||||
ProxyLogger.connect.info("HTTP FORWARD \(method) \(fullURL)")
|
||||
|
||||
if let blockAction = RulesEngine.checkBlockList(url: fullURL, method: method),
|
||||
blockAction != .hideOnly {
|
||||
ProxyLogger.connect.info("HTTP BLOCKED \(fullURL) action=\(blockAction.rawValue)")
|
||||
if blockAction == .blockAndDisplay {
|
||||
var traffic = CapturedTraffic(
|
||||
domain: host, url: fullURL, method: method, scheme: "http",
|
||||
statusCode: 403, statusText: "Blocked",
|
||||
startedAt: Date().timeIntervalSince1970,
|
||||
completedAt: Date().timeIntervalSince1970, durationMs: 0, isSslDecrypted: false
|
||||
)
|
||||
do {
|
||||
try trafficRepo.insert(&traffic)
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
} catch {
|
||||
ProxyLogger.db.error("DB insert failed: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .forbidden)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
pendingHead = nil
|
||||
pendingBody.removeAll()
|
||||
pendingEnd = nil
|
||||
return
|
||||
}
|
||||
|
||||
if let mapRule = RulesEngine.checkMapLocal(url: fullURL, method: method) {
|
||||
ProxyLogger.connect.info("MAP LOCAL match for \(fullURL) -> status \(mapRule.responseStatus)")
|
||||
let status = HTTPResponseStatus(statusCode: mapRule.responseStatus)
|
||||
var headers = decodeHeaders(mapRule.responseHeaders)
|
||||
if let ct = mapRule.responseContentType, !ct.isEmpty {
|
||||
headers.replaceOrAdd(name: "Content-Type", value: ct)
|
||||
}
|
||||
let bodyData = mapRule.responseBody
|
||||
if let bodyData, !bodyData.isEmpty {
|
||||
headers.replaceOrAdd(name: "Content-Length", value: "\(bodyData.count)")
|
||||
}
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: status, headers: headers)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
if let bodyData, !bodyData.isEmpty {
|
||||
var buffer = context.channel.allocator.buffer(capacity: bodyData.count)
|
||||
buffer.writeBytes(bodyData)
|
||||
context.write(wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil)
|
||||
}
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
pendingHead = nil
|
||||
pendingBody.removeAll()
|
||||
pendingEnd = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite the request URI to relative path (upstream expects /path, not http://host/path)
|
||||
var upstreamHead = head
|
||||
upstreamHead.uri = path
|
||||
// Ensure Host header is set
|
||||
if !upstreamHead.headers.contains(name: "Host") {
|
||||
upstreamHead.headers.add(name: "Host", value: host)
|
||||
}
|
||||
@@ -176,7 +433,6 @@ final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
// Remote channel: decode HTTP responses, encode HTTP requests
|
||||
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
|
||||
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
|
||||
}.flatMap {
|
||||
@@ -187,80 +443,90 @@ final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
)
|
||||
}
|
||||
}
|
||||
.connect(host: host, port: port)
|
||||
.connect(host: upstreamHost, port: port)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let remoteChannel):
|
||||
// Forward the buffered request to upstream
|
||||
ProxyLogger.connect.info("HTTP upstream connected to \(upstreamHost):\(port), forwarding request")
|
||||
remoteChannel.write(NIOAny(HTTPClientRequestPart.head(upstreamHead)), promise: nil)
|
||||
for bodyBuffer in self.pendingBody {
|
||||
remoteChannel.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyBuffer))), promise: nil)
|
||||
}
|
||||
remoteChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(self.pendingEnd)), promise: nil)
|
||||
|
||||
// Clear buffered data
|
||||
self.pendingHead = nil
|
||||
self.pendingBody.removeAll()
|
||||
self.pendingEnd = nil
|
||||
|
||||
case .failure(let error):
|
||||
print("[Proxy] HTTP forward failed to \(host):\(port): \(error)")
|
||||
ProxyLogger.connect.error("HTTP upstream connect FAILED \(host):\(port): \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastConnectError = "HTTP \(fullURL): \(error.localizedDescription)"
|
||||
}
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .badGateway)
|
||||
context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
|
||||
self.pendingHead = nil
|
||||
self.pendingBody.removeAll()
|
||||
self.pendingEnd = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func decodeHeaders(_ json: String?) -> HTTPHeaders {
|
||||
guard let json,
|
||||
let data = json.data(using: .utf8),
|
||||
let dict = try? JSONDecoder().decode([String: String].self, from: data) else {
|
||||
return HTTPHeaders()
|
||||
}
|
||||
|
||||
var headers = HTTPHeaders()
|
||||
for (name, value) in dict {
|
||||
headers.add(name: name, value: value)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// MARK: - URL Parsing
|
||||
|
||||
private func parseHTTPTarget(head: HTTPRequestHead) -> (host: String, port: Int, path: String)? {
|
||||
// Absolute URI: "http://example.com:8080/path?query"
|
||||
if head.uri.hasPrefix("http://") || head.uri.hasPrefix("https://") {
|
||||
guard let url = URLComponents(string: head.uri) else { return nil }
|
||||
let host = url.host ?? ""
|
||||
let port = url.port ?? (head.uri.hasPrefix("https") ? 443 : 80)
|
||||
var path = url.path.isEmpty ? "/" : url.path
|
||||
if let query = url.query {
|
||||
path += "?\(query)"
|
||||
}
|
||||
if let query = url.query { path += "?\(query)" }
|
||||
return (host, port, path)
|
||||
}
|
||||
|
||||
// Relative URI with Host header
|
||||
if let hostHeader = head.headers.first(name: "Host") {
|
||||
let parts = hostHeader.split(separator: ":")
|
||||
let host = String(parts[0])
|
||||
let port = parts.count > 1 ? Int(parts[1]) ?? 80 : 80
|
||||
return (host, port, head.uri)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - CONNECT traffic recording
|
||||
|
||||
private func recordConnectTraffic(host: String, port: Int) {
|
||||
private func recordConnectTraffic(host: String, port: Int, isHidden: Bool) {
|
||||
var traffic = CapturedTraffic(
|
||||
domain: host,
|
||||
url: "https://\(host):\(port)",
|
||||
method: "CONNECT",
|
||||
scheme: "https",
|
||||
statusCode: 200,
|
||||
statusText: "Connection Established",
|
||||
startedAt: Date().timeIntervalSince1970,
|
||||
completedAt: Date().timeIntervalSince1970,
|
||||
durationMs: 0,
|
||||
isSslDecrypted: false
|
||||
domain: host, url: "https://\(host):\(port)", method: "CONNECT", scheme: "https",
|
||||
statusCode: 200, statusText: "Connection Established",
|
||||
startedAt: Date().timeIntervalSince1970, completedAt: Date().timeIntervalSince1970,
|
||||
durationMs: 0, isSslDecrypted: false, isHidden: isHidden
|
||||
)
|
||||
try? trafficRepo.insert(&traffic)
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
do {
|
||||
try trafficRepo.insert(&traffic)
|
||||
ProxyLogger.db.debug("Recorded CONNECT \(host) (hidden=\(isHidden))")
|
||||
} catch {
|
||||
ProxyLogger.db.error("Failed to record CONNECT \(host): \(error.localizedDescription)")
|
||||
}
|
||||
NotificationThrottle.shared.throttle()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - HTTPRelayHandler
|
||||
|
||||
/// Relays HTTP responses from the upstream server back to the proxy client.
|
||||
final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
|
||||
@@ -274,11 +540,10 @@ final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
let serverHead = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
|
||||
clientContext.write(wrapResponse(.head(serverHead)), promise: nil)
|
||||
ProxyLogger.connect.debug("HTTPRelay response: \(head.status.code)")
|
||||
clientContext.write(wrapResponse(.head(HTTPResponseHead(version: head.version, status: head.status, headers: head.headers))), promise: nil)
|
||||
case .body(let buffer):
|
||||
clientContext.write(wrapResponse(.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
@@ -287,11 +552,12 @@ final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.connect.debug("HTTPRelay: remote channel inactive")
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[Proxy] Relay error: \(error)")
|
||||
ProxyLogger.connect.error("HTTPRelay error: \(error.localizedDescription)")
|
||||
context.close(promise: nil)
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ import Foundation
|
||||
import NIOCore
|
||||
|
||||
/// Bidirectional TCP forwarder. Pairs two channels so bytes flow in both directions.
|
||||
/// Used for CONNECT tunneling (passthrough mode, no MITM).
|
||||
final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias OutboundOut = ByteBuffer
|
||||
@@ -13,15 +12,19 @@ final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
|
||||
func handlerAdded(context: ChannelHandlerContext) {
|
||||
self.context = context
|
||||
ProxyLogger.glue.debug("GlueHandler added to \(context.channel.localAddress?.description ?? "?")")
|
||||
}
|
||||
|
||||
func handlerRemoved(context: ChannelHandlerContext) {
|
||||
ProxyLogger.glue.debug("GlueHandler removed from \(context.channel.localAddress?.description ?? "?")")
|
||||
self.context = nil
|
||||
self.partner = nil
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
partner?.write(unwrapInboundIn(data))
|
||||
let buf = unwrapInboundIn(data)
|
||||
ProxyLogger.glue.debug("GlueHandler read \(buf.readableBytes) bytes, forwarding to partner")
|
||||
partner?.write(buf)
|
||||
}
|
||||
|
||||
func channelReadComplete(context: ChannelHandlerContext) {
|
||||
@@ -29,10 +32,12 @@ final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.glue.debug("GlueHandler channelInactive — closing partner")
|
||||
partner?.close()
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
ProxyLogger.glue.error("GlueHandler error: \(error.localizedDescription)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
|
||||
@@ -42,8 +47,6 @@ final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Partner operations
|
||||
|
||||
private func write(_ buffer: ByteBuffer) {
|
||||
context?.write(wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ import NIOCore
|
||||
import NIOHTTP1
|
||||
|
||||
/// Captures HTTP request/response pairs and writes them to the traffic database.
|
||||
/// Inserted into the pipeline after TLS termination (MITM) or for plain HTTP.
|
||||
final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
typealias InboundOut = HTTPClientResponsePart
|
||||
@@ -21,10 +20,14 @@ final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
private var responseBody = Data()
|
||||
private var requestStartTime: Double = 0
|
||||
|
||||
private let hardcodedDebugDomain = "okcupid"
|
||||
private let hardcodedDebugNeedle = "jill"
|
||||
|
||||
init(trafficRepo: TrafficRepository, domain: String, scheme: String = "https") {
|
||||
self.trafficRepo = trafficRepo
|
||||
self.domain = domain
|
||||
self.scheme = scheme
|
||||
ProxyLogger.capture.debug("HTTPCaptureHandler created for \(domain) (\(scheme))")
|
||||
}
|
||||
|
||||
// MARK: - Outbound (Request)
|
||||
@@ -33,16 +36,29 @@ final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
let part = unwrapOutboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
case .head(var head):
|
||||
currentRequestId = UUID().uuidString
|
||||
requestHead = head
|
||||
requestBody = Data()
|
||||
requestStartTime = Date().timeIntervalSince1970
|
||||
|
||||
if RulesEngine.shouldStripCache() {
|
||||
head.headers.remove(name: "If-Modified-Since")
|
||||
head.headers.remove(name: "If-None-Match")
|
||||
head.headers.replaceOrAdd(name: "Cache-Control", value: "no-cache")
|
||||
head.headers.replaceOrAdd(name: "Pragma", value: "no-cache")
|
||||
}
|
||||
|
||||
requestHead = head
|
||||
ProxyLogger.capture.info("CAPTURE REQ \(head.method.rawValue) \(self.scheme)://\(self.domain)\(head.uri)")
|
||||
context.write(self.wrapOutboundOut(.head(head)), promise: promise)
|
||||
return
|
||||
case .body(.byteBuffer(let buffer)):
|
||||
if requestBody.count < ProxyConstants.maxBodySizeBytes {
|
||||
requestBody.append(contentsOf: buffer.readableBytesView)
|
||||
}
|
||||
ProxyLogger.capture.debug("CAPTURE REQ body chunk: \(buffer.readableBytes) bytes (total: \(self.requestBody.count))")
|
||||
case .end:
|
||||
ProxyLogger.capture.debug("CAPTURE REQ end — saving to DB")
|
||||
saveRequest()
|
||||
default:
|
||||
break
|
||||
@@ -57,14 +73,27 @@ final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
case .head(var head):
|
||||
if RulesEngine.shouldStripCache() {
|
||||
head.headers.remove(name: "Expires")
|
||||
head.headers.remove(name: "Last-Modified")
|
||||
head.headers.remove(name: "ETag")
|
||||
head.headers.replaceOrAdd(name: "Expires", value: "0")
|
||||
head.headers.replaceOrAdd(name: "Cache-Control", value: "no-cache")
|
||||
}
|
||||
|
||||
responseHead = head
|
||||
responseBody = Data()
|
||||
ProxyLogger.capture.info("CAPTURE RESP \(head.status.code) for \(self.domain)")
|
||||
context.fireChannelRead(NIOAny(HTTPClientResponsePart.head(head)))
|
||||
return
|
||||
case .body(let buffer):
|
||||
if responseBody.count < ProxyConstants.maxBodySizeBytes {
|
||||
responseBody.append(contentsOf: buffer.readableBytesView)
|
||||
}
|
||||
ProxyLogger.capture.debug("CAPTURE RESP body chunk: \(buffer.readableBytes) bytes (total: \(self.responseBody.count))")
|
||||
case .end:
|
||||
ProxyLogger.capture.debug("CAPTURE RESP end — saving to DB")
|
||||
saveResponse()
|
||||
}
|
||||
|
||||
@@ -74,56 +103,79 @@ final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
// MARK: - Persistence
|
||||
|
||||
private func saveRequest() {
|
||||
guard let head = requestHead, let reqId = currentRequestId else { return }
|
||||
guard let head = requestHead, let reqId = currentRequestId else {
|
||||
ProxyLogger.capture.error("saveRequest: no head or requestId!")
|
||||
return
|
||||
}
|
||||
|
||||
let url = "\(scheme)://\(domain)\(head.uri)"
|
||||
let headersJSON = encodeHeaders(head.headers)
|
||||
let queryParams = extractQueryParams(from: head.uri)
|
||||
let shouldHide =
|
||||
(IPCManager.shared.hideSystemTraffic && SystemTrafficFilter.isSystemDomain(domain)) ||
|
||||
RulesEngine.checkBlockList(url: url, method: head.method.rawValue) == .hideOnly
|
||||
|
||||
let headerCount = head.headers.count
|
||||
let bodySize = requestBody.count
|
||||
|
||||
var traffic = CapturedTraffic(
|
||||
requestId: reqId,
|
||||
domain: domain,
|
||||
url: url,
|
||||
method: head.method.rawValue,
|
||||
scheme: scheme,
|
||||
requestId: reqId, domain: domain, url: url,
|
||||
method: head.method.rawValue, scheme: scheme,
|
||||
requestHeaders: headersJSON,
|
||||
requestBody: requestBody.isEmpty ? nil : requestBody,
|
||||
requestBodySize: requestBody.count,
|
||||
requestContentType: head.headers.first(name: "Content-Type"),
|
||||
queryParameters: queryParams,
|
||||
startedAt: requestStartTime,
|
||||
isSslDecrypted: scheme == "https"
|
||||
isSslDecrypted: scheme == "https",
|
||||
isHidden: shouldHide
|
||||
)
|
||||
|
||||
try? trafficRepo.insert(&traffic)
|
||||
do {
|
||||
try trafficRepo.insert(&traffic)
|
||||
ProxyLogger.capture.info("DB INSERT OK: \(head.method.rawValue) \(self.domain) headers=\(headerCount) body=\(bodySize)B id=\(reqId)")
|
||||
} catch {
|
||||
ProxyLogger.capture.error("DB INSERT FAILED: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
private func saveResponse() {
|
||||
guard let reqId = currentRequestId, let head = responseHead else { return }
|
||||
guard let reqId = currentRequestId, let head = responseHead else {
|
||||
ProxyLogger.capture.error("saveResponse: no requestId or responseHead!")
|
||||
return
|
||||
}
|
||||
|
||||
let now = Date().timeIntervalSince1970
|
||||
let durationMs = Int((now - requestStartTime) * 1000)
|
||||
let headerCount = head.headers.count
|
||||
let bodySize = responseBody.count
|
||||
|
||||
try? trafficRepo.updateResponse(
|
||||
requestId: reqId,
|
||||
statusCode: Int(head.status.code),
|
||||
statusText: head.status.reasonPhrase,
|
||||
responseHeaders: encodeHeaders(head.headers),
|
||||
responseBody: responseBody.isEmpty ? nil : responseBody,
|
||||
responseBodySize: responseBody.count,
|
||||
responseContentType: head.headers.first(name: "Content-Type"),
|
||||
completedAt: now,
|
||||
durationMs: durationMs
|
||||
)
|
||||
do {
|
||||
try trafficRepo.updateResponse(
|
||||
requestId: reqId,
|
||||
statusCode: Int(head.status.code),
|
||||
statusText: head.status.reasonPhrase,
|
||||
responseHeaders: encodeHeaders(head.headers),
|
||||
responseBody: responseBody.isEmpty ? nil : responseBody,
|
||||
responseBodySize: responseBody.count,
|
||||
responseContentType: head.headers.first(name: "Content-Type"),
|
||||
completedAt: now,
|
||||
durationMs: durationMs
|
||||
)
|
||||
ProxyLogger.capture.info("DB UPDATE OK: \(head.status.code) \(self.domain) headers=\(headerCount) body=\(bodySize)B duration=\(durationMs)ms id=\(reqId)")
|
||||
} catch {
|
||||
ProxyLogger.capture.error("DB UPDATE FAILED for \(reqId): \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
logHardcodedBodyDebug(responseHead: head, requestId: reqId)
|
||||
|
||||
// Debounce — don't flood with notifications for every single response
|
||||
NotificationThrottle.shared.throttle()
|
||||
}
|
||||
|
||||
private func encodeHeaders(_ headers: HTTPHeaders) -> String? {
|
||||
var dict: [String: String] = [:]
|
||||
for (name, value) in headers {
|
||||
dict[name] = value
|
||||
}
|
||||
for (name, value) in headers { dict[name] = value }
|
||||
guard let data = try? JSONEncoder().encode(dict) else { return nil }
|
||||
return String(data: data, encoding: .utf8)
|
||||
}
|
||||
@@ -132,10 +184,59 @@ final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
guard let url = URLComponents(string: uri),
|
||||
let items = url.queryItems, !items.isEmpty else { return nil }
|
||||
var dict: [String: String] = [:]
|
||||
for item in items {
|
||||
dict[item.name] = item.value ?? ""
|
||||
}
|
||||
for item in items { dict[item.name] = item.value ?? "" }
|
||||
guard let data = try? JSONEncoder().encode(dict) else { return nil }
|
||||
return String(data: data, encoding: .utf8)
|
||||
}
|
||||
|
||||
private func logHardcodedBodyDebug(responseHead: HTTPResponseHead, requestId: String) {
|
||||
let responseHeaders = headerDictionary(from: responseHead.headers)
|
||||
let decodedBody = HTTPBodyDecoder.decodedBodyData(from: responseBody, headers: responseHeaders)
|
||||
let searchableBody = HTTPBodyDecoder.searchableText(from: responseBody, headers: responseHeaders) ?? ""
|
||||
let preview = decodedBodyPreview(headers: responseHeaders)
|
||||
|
||||
guard domain.localizedCaseInsensitiveContains(hardcodedDebugDomain) ||
|
||||
requestHead?.uri.localizedCaseInsensitiveContains(hardcodedDebugDomain) == true ||
|
||||
preview.localizedCaseInsensitiveContains(hardcodedDebugNeedle) else {
|
||||
return
|
||||
}
|
||||
|
||||
let contentType = responseHead.headers.first(name: "Content-Type") ?? "nil"
|
||||
let contentEncoding = responseHead.headers.first(name: "Content-Encoding") ?? "nil"
|
||||
let containsNeedle = searchableBody.localizedCaseInsensitiveContains(hardcodedDebugNeedle)
|
||||
let decodingHint = HTTPBodyDecoder.decodingHint(for: responseBody, headers: responseHeaders)
|
||||
|
||||
ProxyLogger.capture.info(
|
||||
"""
|
||||
HARDCODED DEBUG capture domain=\(self.domain) id=\(requestId) status=\(responseHead.status.code) \
|
||||
contentType=\(contentType) contentEncoding=\(contentEncoding) bodyBytes=\(self.responseBody.count) \
|
||||
decodedBytes=\(decodedBody?.count ?? 0) decoding=\(decodingHint) containsNeedle=\(containsNeedle)
|
||||
"""
|
||||
)
|
||||
|
||||
if containsNeedle {
|
||||
ProxyLogger.capture.info("HARDCODED DEBUG MATCH needle=\(self.hardcodedDebugNeedle) preview=\(preview)")
|
||||
} else {
|
||||
ProxyLogger.capture.info("HARDCODED DEBUG NO_MATCH needle=\(self.hardcodedDebugNeedle) preview=\(preview)")
|
||||
}
|
||||
}
|
||||
|
||||
private func decodedSearchableBody(headers: [String: String]) -> String {
|
||||
HTTPBodyDecoder.searchableText(from: responseBody, headers: headers) ?? ""
|
||||
}
|
||||
|
||||
private func decodedBodyPreview(headers: [String: String]) -> String {
|
||||
let raw = decodedSearchableBody(headers: headers)
|
||||
.replacingOccurrences(of: "\n", with: " ")
|
||||
.replacingOccurrences(of: "\r", with: " ")
|
||||
return String(raw.prefix(240))
|
||||
}
|
||||
|
||||
private func headerDictionary(from headers: HTTPHeaders) -> [String: String] {
|
||||
var dictionary: [String: String] = [:]
|
||||
for (name, value) in headers {
|
||||
dictionary[name] = value
|
||||
}
|
||||
return dictionary
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,110 +4,108 @@ import NIOPosix
|
||||
import NIOSSL
|
||||
import NIOHTTP1
|
||||
|
||||
/// After a CONNECT tunnel is established, this handler:
|
||||
/// 1. Reads the first bytes from the client to extract the SNI hostname from the TLS ClientHello
|
||||
/// 2. Generates a per-domain leaf certificate via CertificateManager
|
||||
/// 3. Terminates client-side TLS with the generated cert
|
||||
/// 4. Initiates server-side TLS to the real server
|
||||
/// 5. Installs HTTP codecs + HTTPCaptureHandler on both sides to capture decrypted traffic
|
||||
final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private let host: String
|
||||
private let originalHost: String
|
||||
private let upstreamHost: String
|
||||
private let port: Int
|
||||
private let trafficRepo: TrafficRepository
|
||||
private let certManager: CertificateManager
|
||||
private let runtimeStatusRepo = RuntimeStatusRepository()
|
||||
|
||||
init(host: String, port: Int, trafficRepo: TrafficRepository, certManager: CertificateManager = .shared) {
|
||||
self.host = host
|
||||
init(
|
||||
originalHost: String,
|
||||
upstreamHost: String,
|
||||
port: Int,
|
||||
trafficRepo: TrafficRepository,
|
||||
certManager: CertificateManager = .shared
|
||||
) {
|
||||
self.originalHost = originalHost
|
||||
self.upstreamHost = upstreamHost
|
||||
self.port = port
|
||||
self.trafficRepo = trafficRepo
|
||||
self.certManager = certManager
|
||||
ProxyLogger.mitm.info("MITMHandler created original=\(originalHost) upstream=\(upstreamHost):\(port)")
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
var buffer = unwrapInboundIn(data)
|
||||
let bufferSize = buffer.readableBytes
|
||||
|
||||
// Extract SNI from ClientHello if possible, otherwise use the CONNECT host
|
||||
let sniDomain = extractSNI(from: buffer) ?? host
|
||||
let sniDomain = extractSNI(from: buffer) ?? originalHost
|
||||
ProxyLogger.mitm.info("MITM ClientHello: \(bufferSize) bytes, SNI=\(sniDomain) (fallback host=\(self.originalHost))")
|
||||
|
||||
// Remove this handler — we'll rebuild the pipeline
|
||||
context.pipeline.removeHandler(self, promise: nil)
|
||||
|
||||
// Get TLS context for this domain
|
||||
let sslContext: NIOSSLContext
|
||||
do {
|
||||
sslContext = try certManager.tlsServerContext(for: sniDomain)
|
||||
ProxyLogger.mitm.info("MITM TLS context created for \(sniDomain)")
|
||||
} catch {
|
||||
print("[MITM] Failed to get TLS context for \(sniDomain): \(error)")
|
||||
ProxyLogger.mitm.error("MITM TLS context FAILED for \(sniDomain): \(error.localizedDescription)")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "TLS context \(sniDomain): \(error.localizedDescription)"
|
||||
}
|
||||
context.close(promise: nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Add server-side TLS handler (we are the "server" to the client)
|
||||
let sslServerHandler = NIOSSLServerHandler(context: sslContext)
|
||||
let trafficRepo = self.trafficRepo
|
||||
let host = self.host
|
||||
let originalHost = self.originalHost
|
||||
let upstreamHost = self.upstreamHost
|
||||
let port = self.port
|
||||
let runtimeStatusRepo = self.runtimeStatusRepo
|
||||
let tlsErrorHandler = TLSErrorLogger(label: "CLIENT-SIDE", domain: sniDomain, runtimeStatusRepo: runtimeStatusRepo)
|
||||
|
||||
context.channel.pipeline.addHandler(sslServerHandler, position: .first).flatMap {
|
||||
// Add HTTP codec after TLS
|
||||
// Add TLS error logger right after the SSL handler to catch handshake failures
|
||||
context.channel.pipeline.addHandler(tlsErrorHandler)
|
||||
}.flatMap {
|
||||
context.channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))
|
||||
}.flatMap {
|
||||
context.channel.pipeline.addHandler(HTTPResponseEncoder())
|
||||
}.flatMap {
|
||||
// Add the forwarding handler that connects to the real server
|
||||
context.channel.pipeline.addHandler(
|
||||
MITMForwardHandler(
|
||||
remoteHost: host,
|
||||
remoteHost: upstreamHost,
|
||||
remotePort: port,
|
||||
domain: sniDomain,
|
||||
originalDomain: originalHost,
|
||||
trafficRepo: trafficRepo
|
||||
)
|
||||
)
|
||||
}.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
// Re-fire the original ClientHello bytes so TLS handshake proceeds
|
||||
ProxyLogger.mitm.info("MITM pipeline installed for \(sniDomain), re-firing ClientHello (\(bufferSize) bytes)")
|
||||
context.channel.pipeline.fireChannelRead(NIOAny(buffer))
|
||||
case .failure(let error):
|
||||
print("[MITM] Pipeline setup failed: \(error)")
|
||||
ProxyLogger.mitm.error("MITM pipeline setup FAILED for \(sniDomain): \(error)")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "Pipeline setup \(sniDomain): \(error.localizedDescription)"
|
||||
}
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - SNI Extraction
|
||||
|
||||
/// Parse the SNI hostname from a TLS ClientHello message.
|
||||
private func extractSNI(from buffer: ByteBuffer) -> String? {
|
||||
var buf = buffer
|
||||
guard buf.readableBytes >= 43 else { return nil }
|
||||
|
||||
// TLS record header
|
||||
guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil } // Handshake
|
||||
let _ = buf.readInteger(as: UInt16.self) // Version
|
||||
let _ = buf.readInteger(as: UInt16.self) // Length
|
||||
|
||||
// Handshake header
|
||||
guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil } // ClientHello
|
||||
let _ = buf.readBytes(length: 3) // Length (3 bytes)
|
||||
|
||||
// Client version
|
||||
guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil }
|
||||
let _ = buf.readInteger(as: UInt16.self)
|
||||
let _ = buf.readInteger(as: UInt16.self)
|
||||
guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil }
|
||||
let _ = buf.readBytes(length: 3)
|
||||
let _ = buf.readInteger(as: UInt16.self)
|
||||
// Random (32 bytes)
|
||||
guard buf.readBytes(length: 32) != nil else { return nil }
|
||||
// Session ID
|
||||
guard let sessionIdLen = buf.readInteger(as: UInt8.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(sessionIdLen)) != nil else { return nil }
|
||||
// Cipher suites
|
||||
guard let cipherSuitesLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(cipherSuitesLen)) != nil else { return nil }
|
||||
// Compression methods
|
||||
guard let compMethodsLen = buf.readInteger(as: UInt8.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(compMethodsLen)) != nil else { return nil }
|
||||
|
||||
// Extensions
|
||||
guard let extensionsLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
var extensionsRemaining = Int(extensionsLen)
|
||||
|
||||
@@ -116,50 +114,48 @@ final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
let extLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
extensionsRemaining -= 4 + Int(extLen)
|
||||
|
||||
if extType == 0x0000 { // SNI extension
|
||||
guard let _ = buf.readInteger(as: UInt16.self), // SNI list length
|
||||
if extType == 0x0000 {
|
||||
guard let _ = buf.readInteger(as: UInt16.self),
|
||||
let nameType = buf.readInteger(as: UInt8.self),
|
||||
nameType == 0x00, // hostname
|
||||
nameType == 0x00,
|
||||
let nameLen = buf.readInteger(as: UInt16.self),
|
||||
let nameBytes = buf.readBytes(length: Int(nameLen)) else {
|
||||
return nil
|
||||
}
|
||||
return String(bytes: nameBytes, encoding: .utf8)
|
||||
let nameBytes = buf.readBytes(length: Int(nameLen)) else { return nil }
|
||||
let name = String(bytes: nameBytes, encoding: .utf8)
|
||||
ProxyLogger.mitm.debug("SNI extracted: \(name ?? "nil")")
|
||||
return name
|
||||
} else {
|
||||
guard buf.readBytes(length: Int(extLen)) != nil else { return nil }
|
||||
}
|
||||
}
|
||||
|
||||
ProxyLogger.mitm.debug("SNI: not found in ClientHello")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - MITMForwardHandler
|
||||
|
||||
/// Handles decrypted HTTP from the client, forwards to the real server over TLS,
|
||||
/// and relays responses back. Captures everything via HTTPCaptureHandler.
|
||||
final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPServerRequestPart
|
||||
typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private let remoteHost: String
|
||||
private let remotePort: Int
|
||||
private let domain: String
|
||||
private let originalDomain: String
|
||||
private let trafficRepo: TrafficRepository
|
||||
private let runtimeStatusRepo = RuntimeStatusRepository()
|
||||
private var remoteChannel: Channel?
|
||||
|
||||
// Buffer request parts until upstream is connected
|
||||
private var pendingParts: [HTTPServerRequestPart] = []
|
||||
private var isConnected = false
|
||||
|
||||
init(remoteHost: String, remotePort: Int, domain: String, trafficRepo: TrafficRepository) {
|
||||
init(remoteHost: String, remotePort: Int, originalDomain: String, trafficRepo: TrafficRepository) {
|
||||
self.remoteHost = remoteHost
|
||||
self.remotePort = remotePort
|
||||
self.domain = domain
|
||||
self.originalDomain = originalDomain
|
||||
self.trafficRepo = trafficRepo
|
||||
}
|
||||
|
||||
func handlerAdded(context: ChannelHandlerContext) {
|
||||
ProxyLogger.mitm.info("MITMForward: connecting to upstream \(self.remoteHost):\(self.remotePort)")
|
||||
connectToRemote(context: context)
|
||||
}
|
||||
|
||||
@@ -167,12 +163,17 @@ final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
if isConnected, let remote = remoteChannel {
|
||||
// Forward to upstream as client request
|
||||
switch part {
|
||||
case .head(let head):
|
||||
ProxyLogger.mitm.info("MITMForward: decrypted request \(head.method.rawValue) \(head.uri)")
|
||||
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
|
||||
if !clientHead.headers.contains(name: "Host") {
|
||||
clientHead.headers.add(name: "Host", value: domain)
|
||||
clientHead.headers.add(name: "Host", value: originalDomain)
|
||||
}
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastSuccessfulMITMDomain = self.originalDomain
|
||||
$0.lastSuccessfulMITMAt = Date().timeIntervalSince1970
|
||||
$0.lastMITMError = nil
|
||||
}
|
||||
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
|
||||
case .body(let buffer):
|
||||
@@ -181,22 +182,27 @@ final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
|
||||
}
|
||||
} else {
|
||||
ProxyLogger.mitm.debug("MITMForward: buffering request part (not connected yet)")
|
||||
pendingParts.append(part)
|
||||
}
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.mitm.debug("MITMForward: client channel inactive")
|
||||
remoteChannel?.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[MITMForward] Error: \(error)")
|
||||
ProxyLogger.mitm.error("MITMForward error: \(error.localizedDescription)")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "Forwarding \(self.originalDomain): \(error.localizedDescription)"
|
||||
}
|
||||
context.close(promise: nil)
|
||||
remoteChannel?.close(promise: nil)
|
||||
}
|
||||
|
||||
private func connectToRemote(context: ChannelHandlerContext) {
|
||||
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: domain, scheme: "https")
|
||||
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: originalDomain, scheme: "https")
|
||||
let clientContext = context
|
||||
|
||||
do {
|
||||
@@ -206,44 +212,63 @@ final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
let sniHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.domain)
|
||||
let sniHandler: NIOSSLClientHandler
|
||||
do {
|
||||
sniHandler = try NIOSSLClientHandler(context: sslContext, serverHostname: self.originalDomain)
|
||||
} catch {
|
||||
ProxyLogger.mitm.error("NIOSSLClientHandler init FAILED: \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "Client TLS handler \(self.originalDomain): \(error.localizedDescription)"
|
||||
}
|
||||
channel.close(promise: nil)
|
||||
return channel.eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
let upstreamTLSLogger = TLSErrorLogger(label: "UPSTREAM", domain: self.originalDomain, runtimeStatusRepo: self.runtimeStatusRepo)
|
||||
return channel.pipeline.addHandler(sniHandler).flatMap {
|
||||
channel.pipeline.addHandler(upstreamTLSLogger)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(HTTPRequestEncoder())
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(captureHandler)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(
|
||||
MITMRelayHandler(clientContext: clientContext)
|
||||
)
|
||||
channel.pipeline.addHandler(MITMRelayHandler(clientContext: clientContext))
|
||||
}
|
||||
}
|
||||
.connect(host: remoteHost, port: remotePort)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let channel):
|
||||
ProxyLogger.mitm.info("MITMForward: upstream connected to \(self.remoteHost):\(self.remotePort)")
|
||||
self.remoteChannel = channel
|
||||
self.isConnected = true
|
||||
self.flushPending(remote: channel)
|
||||
case .failure(let error):
|
||||
print("[MITMForward] Connect to \(self.remoteHost):\(self.remotePort) failed: \(error)")
|
||||
ProxyLogger.mitm.error("MITMForward: upstream connect FAILED \(self.remoteHost):\(self.remotePort): \(error.localizedDescription)")
|
||||
self.runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "Upstream \(self.originalDomain): \(error.localizedDescription)"
|
||||
}
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[MITMForward] TLS setup failed: \(error)")
|
||||
ProxyLogger.mitm.error("MITMForward: TLS context creation FAILED: \(error.localizedDescription)")
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "TLS configuration \(self.originalDomain): \(error.localizedDescription)"
|
||||
}
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
private func flushPending(remote: Channel) {
|
||||
ProxyLogger.mitm.debug("MITMForward: flushing \(self.pendingParts.count) buffered parts")
|
||||
for part in pendingParts {
|
||||
switch part {
|
||||
case .head(let head):
|
||||
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
|
||||
if !clientHead.headers.contains(name: "Host") {
|
||||
clientHead.headers.add(name: "Host", value: domain)
|
||||
clientHead.headers.add(name: "Host", value: originalDomain)
|
||||
}
|
||||
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
|
||||
case .body(let buffer):
|
||||
@@ -258,7 +283,6 @@ final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
|
||||
// MARK: - MITMRelayHandler
|
||||
|
||||
/// Relays responses from the real server back to the proxy client.
|
||||
final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
|
||||
@@ -270,11 +294,10 @@ final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
let serverResponse = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
|
||||
clientContext.write(NIOAny(HTTPServerResponsePart.head(serverResponse)), promise: nil)
|
||||
ProxyLogger.mitm.debug("MITMRelay response: \(head.status.code)")
|
||||
clientContext.write(NIOAny(HTTPServerResponsePart.head(HTTPResponseHead(version: head.version, status: head.status, headers: head.headers))), promise: nil)
|
||||
case .body(let buffer):
|
||||
clientContext.write(NIOAny(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
@@ -283,12 +306,105 @@ final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.mitm.debug("MITMRelay: remote inactive")
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[MITMRelay] Error: \(error)")
|
||||
ProxyLogger.mitm.error("MITMRelay error: \(error.localizedDescription)")
|
||||
RuntimeStatusRepository().update {
|
||||
$0.lastMITMError = "Relay response: \(error.localizedDescription)"
|
||||
}
|
||||
context.close(promise: nil)
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - TLSErrorLogger
|
||||
|
||||
/// Catches and logs TLS handshake errors with detailed context.
|
||||
/// Placed right after NIOSSLServerHandler/NIOSSLClientHandler in the pipeline.
|
||||
final class TLSErrorLogger: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = NIOAny
|
||||
|
||||
private let label: String
|
||||
private let domain: String
|
||||
private let runtimeStatusRepo: RuntimeStatusRepository
|
||||
|
||||
init(label: String, domain: String, runtimeStatusRepo: RuntimeStatusRepository) {
|
||||
self.label = label
|
||||
self.domain = domain
|
||||
self.runtimeStatusRepo = runtimeStatusRepo
|
||||
}
|
||||
|
||||
func channelActive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.mitm.info("TLS[\(self.label)] \(self.domain): channel active (handshake starting)")
|
||||
context.fireChannelActive()
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
ProxyLogger.mitm.info("TLS[\(self.label)] \(self.domain): channel inactive")
|
||||
context.fireChannelInactive()
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
// TLS handshake completed if we're getting data through
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
||||
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||
if let tlsEvent = event as? NIOSSLVerificationCallback {
|
||||
ProxyLogger.mitm.info("TLS[\(self.label)] \(self.domain): verification callback triggered")
|
||||
}
|
||||
// Check for handshake completion by string matching the event type
|
||||
let eventDesc = String(describing: event)
|
||||
if eventDesc.contains("handshakeCompleted") {
|
||||
ProxyLogger.mitm.info("TLS[\(self.label)] \(self.domain): HANDSHAKE COMPLETED event=\(eventDesc)")
|
||||
} else {
|
||||
ProxyLogger.mitm.debug("TLS[\(self.label)] \(self.domain): user event=\(eventDesc)")
|
||||
}
|
||||
context.fireUserInboundEventTriggered(event)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
let errorDesc = String(describing: error)
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): ERROR \(errorDesc)")
|
||||
|
||||
// Categorize and detect SSL pinning
|
||||
let lowerError = errorDesc.lowercased()
|
||||
var isPinningLikely = false
|
||||
var category = "UNKNOWN"
|
||||
|
||||
if lowerError.contains("certificate") || lowerError.contains("trust") {
|
||||
category = "CERTIFICATE_TRUST"
|
||||
isPinningLikely = label == "CLIENT-SIDE"
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): CERTIFICATE TRUST ISSUE — client likely doesn't trust our CA")
|
||||
} else if lowerError.contains("handshake") {
|
||||
category = "HANDSHAKE_FAILURE"
|
||||
isPinningLikely = label == "CLIENT-SIDE"
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): HANDSHAKE FAILURE — protocol mismatch or cert rejected")
|
||||
} else if lowerError.contains("eof") || lowerError.contains("reset") || lowerError.contains("closed") || lowerError.contains("connection") {
|
||||
category = "CONNECTION_RESET"
|
||||
isPinningLikely = label == "CLIENT-SIDE"
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): CONNECTION RESET during handshake (SSL pinning suspected)")
|
||||
} else if lowerError.contains("unrecognized") || lowerError.contains("alert") || lowerError.contains("fatal") {
|
||||
category = "TLS_ALERT"
|
||||
isPinningLikely = true
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): TLS ALERT — peer sent alert (unknown_ca / bad_certificate)")
|
||||
}
|
||||
|
||||
// If this is a client-side error (the app rejected our cert), it's likely SSL pinning.
|
||||
// Auto-record this domain as pinned so future connections use passthrough.
|
||||
if isPinningLikely && label == "CLIENT-SIDE" {
|
||||
let reason = "TLS \(category): \(String(errorDesc.prefix(200)))"
|
||||
PinnedDomainRepository().markPinned(domain: domain, reason: reason)
|
||||
ProxyLogger.mitm.error("TLS[\(self.label)] \(self.domain): AUTO-PINNED — future connections will use passthrough")
|
||||
}
|
||||
|
||||
runtimeStatusRepo.update {
|
||||
$0.lastMITMError = "TLS[\(self.label)] \(self.domain) [\(category)]: \(String(errorDesc.prefix(200)))"
|
||||
}
|
||||
|
||||
context.fireErrorCaught(error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,19 +17,21 @@ public final class ProxyServer: Sendable {
|
||||
) {
|
||||
self.host = host
|
||||
self.port = port
|
||||
// Use only 1 thread to conserve memory in the extension (50MB budget)
|
||||
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
self.trafficRepo = trafficRepo
|
||||
ProxyLogger.proxy.info("ProxyServer init: \(host):\(port)")
|
||||
}
|
||||
|
||||
public func start() async throws {
|
||||
let trafficRepo = self.trafficRepo
|
||||
|
||||
ProxyLogger.proxy.info("ProxyServer binding to \(self.host):\(self.port)...")
|
||||
let bootstrap = ServerBootstrap(group: group)
|
||||
.serverChannelOption(.backlog, value: 256)
|
||||
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.childChannelInitializer { channel in
|
||||
channel.pipeline.addHandler(
|
||||
ProxyLogger.proxy.debug("New client connection from \(channel.remoteAddress?.description ?? "unknown")")
|
||||
return channel.pipeline.addHandler(
|
||||
ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
|
||||
).flatMap {
|
||||
channel.pipeline.addHandler(HTTPResponseEncoder())
|
||||
@@ -41,15 +43,17 @@ public final class ProxyServer: Sendable {
|
||||
.childChannelOption(.maxMessagesPerRead, value: 16)
|
||||
|
||||
channel = try await bootstrap.bind(host: host, port: port).get()
|
||||
print("[ProxyServer] Listening on \(host):\(port)")
|
||||
ProxyLogger.proxy.info("ProxyServer LISTENING on \(self.host):\(self.port)")
|
||||
}
|
||||
|
||||
public func stop() async {
|
||||
ProxyLogger.proxy.info("ProxyServer stopping...")
|
||||
do {
|
||||
try await channel?.close()
|
||||
try await group.shutdownGracefully()
|
||||
ProxyLogger.proxy.info("ProxyServer stopped cleanly")
|
||||
} catch {
|
||||
print("[ProxyServer] Shutdown error: \(error)")
|
||||
ProxyLogger.proxy.error("ProxyServer shutdown error: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
82
ProxyCore/Sources/ProxyEngine/RulesEngine.swift
Normal file
82
ProxyCore/Sources/ProxyEngine/RulesEngine.swift
Normal file
@@ -0,0 +1,82 @@
|
||||
import Foundation
|
||||
|
||||
/// Centralized rules engine that checks proxy rules (block list, map local, DNS spoofing, no-cache)
|
||||
/// against live traffic. All methods are static and synchronous for use in NIO pipeline handlers.
|
||||
public enum RulesEngine {
|
||||
|
||||
private static let rulesRepo = RulesRepository()
|
||||
|
||||
// MARK: - Block List
|
||||
|
||||
/// Returns the `BlockAction` if the given URL + method matches an enabled block rule, or nil.
|
||||
public static func checkBlockList(url: String, method: String) -> BlockAction? {
|
||||
guard IPCManager.shared.isBlockListEnabled else { return nil }
|
||||
do {
|
||||
let entries = try rulesRepo.fetchEnabledBlockEntries()
|
||||
for entry in entries {
|
||||
guard entry.method == "ANY" || entry.method == method else { continue }
|
||||
if blockEntry(entry, matches: url) {
|
||||
return entry.action
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[RulesEngine] Failed to check block list: \(error)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - Map Local
|
||||
|
||||
/// Returns the first matching `MapLocalRule` for the URL + method, or nil.
|
||||
public static func checkMapLocal(url: String, method: String) -> MapLocalRule? {
|
||||
do {
|
||||
let rules = try rulesRepo.fetchEnabledMapLocalRules()
|
||||
for rule in rules {
|
||||
guard rule.method == "ANY" || rule.method == method else { continue }
|
||||
if WildcardMatcher.matches(url, pattern: rule.urlPattern) {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[RulesEngine] Failed to check map local rules: \(error)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - DNS Spoofing
|
||||
|
||||
/// Returns the target domain if the given domain matches an enabled DNS spoof rule, or nil.
|
||||
public static func checkDNSSpoof(domain: String) -> String? {
|
||||
guard IPCManager.shared.isDNSSpoofingEnabled else { return nil }
|
||||
do {
|
||||
let rules = try rulesRepo.fetchEnabledDNSSpoofRules()
|
||||
for rule in rules {
|
||||
if WildcardMatcher.matches(domain, pattern: rule.sourceDomain) {
|
||||
return rule.targetDomain
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[RulesEngine] Failed to check DNS spoof rules: \(error)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - No-Cache
|
||||
|
||||
/// Returns true if the no-caching toggle is enabled.
|
||||
public static func shouldStripCache() -> Bool {
|
||||
IPCManager.shared.isNoCachingEnabled
|
||||
}
|
||||
|
||||
private static func blockEntry(_ entry: BlockListEntry, matches url: String) -> Bool {
|
||||
if WildcardMatcher.matches(url, pattern: entry.urlPattern) {
|
||||
return true
|
||||
}
|
||||
|
||||
guard entry.includeSubpaths else { return false }
|
||||
guard !entry.urlPattern.contains("*"), !entry.urlPattern.contains("?") else { return false }
|
||||
|
||||
let normalizedPattern = entry.urlPattern.hasSuffix("/") ? entry.urlPattern : "\(entry.urlPattern)/"
|
||||
return url == entry.urlPattern || url.hasPrefix(normalizedPattern)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user