Initial project setup - Phases 1-3 complete
This commit is contained in:
134
ProxyCore/Sources/DataLayer/Database/DatabaseManager.swift
Normal file
134
ProxyCore/Sources/DataLayer/Database/DatabaseManager.swift
Normal file
@@ -0,0 +1,134 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public final class DatabaseManager: Sendable {
|
||||
public let dbPool: DatabasePool
|
||||
|
||||
public static let shared: DatabaseManager = {
|
||||
let url = FileManager.default
|
||||
.containerURL(forSecurityApplicationGroupIdentifier: "group.com.treyt.proxyapp")!
|
||||
.appendingPathComponent("proxy.sqlite")
|
||||
return try! DatabaseManager(path: url.path)
|
||||
}()
|
||||
|
||||
public init(path: String) throws {
|
||||
var config = Configuration()
|
||||
config.prepareDatabase { db in
|
||||
// WAL mode for cross-process concurrent access
|
||||
try db.execute(sql: "PRAGMA journal_mode = WAL")
|
||||
try db.execute(sql: "PRAGMA synchronous = NORMAL")
|
||||
}
|
||||
dbPool = try DatabasePool(path: path, configuration: config)
|
||||
try migrate()
|
||||
}
|
||||
|
||||
private func migrate() throws {
|
||||
var migrator = DatabaseMigrator()
|
||||
|
||||
migrator.registerMigration("v1_create_tables") { db in
|
||||
try db.create(table: "captured_traffic") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("requestId", .text).notNull().unique()
|
||||
t.column("domain", .text).notNull().indexed()
|
||||
t.column("url", .text).notNull()
|
||||
t.column("method", .text).notNull()
|
||||
t.column("scheme", .text).notNull()
|
||||
t.column("statusCode", .integer)
|
||||
t.column("statusText", .text)
|
||||
|
||||
t.column("requestHeaders", .text)
|
||||
t.column("requestBody", .blob)
|
||||
t.column("requestBodySize", .integer).notNull().defaults(to: 0)
|
||||
t.column("requestContentType", .text)
|
||||
t.column("queryParameters", .text)
|
||||
|
||||
t.column("responseHeaders", .text)
|
||||
t.column("responseBody", .blob)
|
||||
t.column("responseBodySize", .integer).notNull().defaults(to: 0)
|
||||
t.column("responseContentType", .text)
|
||||
|
||||
t.column("startedAt", .double).notNull()
|
||||
t.column("completedAt", .double)
|
||||
t.column("durationMs", .integer)
|
||||
|
||||
t.column("isSslDecrypted", .boolean).notNull().defaults(to: false)
|
||||
t.column("isPinned", .boolean).notNull().defaults(to: false)
|
||||
t.column("isWebsocket", .boolean).notNull().defaults(to: false)
|
||||
t.column("isHidden", .boolean).notNull().defaults(to: false)
|
||||
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(index: "idx_traffic_started_at", on: "captured_traffic", columns: ["startedAt"])
|
||||
try db.create(index: "idx_traffic_pinned", on: "captured_traffic", columns: ["isPinned"])
|
||||
|
||||
try db.create(table: "ssl_proxying_entries") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("domainPattern", .text).notNull()
|
||||
t.column("isInclude", .boolean).notNull()
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(table: "block_list_entries") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("name", .text)
|
||||
t.column("urlPattern", .text).notNull()
|
||||
t.column("method", .text).notNull().defaults(to: "ANY")
|
||||
t.column("includeSubpaths", .boolean).notNull().defaults(to: true)
|
||||
t.column("blockAction", .text).notNull().defaults(to: "block_and_hide")
|
||||
t.column("isEnabled", .boolean).notNull().defaults(to: true)
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(table: "breakpoint_rules") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("name", .text)
|
||||
t.column("urlPattern", .text).notNull()
|
||||
t.column("method", .text).notNull().defaults(to: "ANY")
|
||||
t.column("interceptRequest", .boolean).notNull().defaults(to: true)
|
||||
t.column("interceptResponse", .boolean).notNull().defaults(to: true)
|
||||
t.column("isEnabled", .boolean).notNull().defaults(to: true)
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(table: "map_local_rules") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("name", .text)
|
||||
t.column("urlPattern", .text).notNull()
|
||||
t.column("method", .text).notNull().defaults(to: "ANY")
|
||||
t.column("responseStatus", .integer).notNull().defaults(to: 200)
|
||||
t.column("responseHeaders", .text)
|
||||
t.column("responseBody", .blob)
|
||||
t.column("responseContentType", .text)
|
||||
t.column("isEnabled", .boolean).notNull().defaults(to: true)
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(table: "dns_spoof_rules") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("sourceDomain", .text).notNull()
|
||||
t.column("targetDomain", .text).notNull()
|
||||
t.column("isEnabled", .boolean).notNull().defaults(to: true)
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
|
||||
try db.create(table: "compose_requests") { t in
|
||||
t.autoIncrementedPrimaryKey("id")
|
||||
t.column("name", .text).notNull().defaults(to: "New Request")
|
||||
t.column("method", .text).notNull().defaults(to: "GET")
|
||||
t.column("url", .text)
|
||||
t.column("headers", .text)
|
||||
t.column("queryParameters", .text)
|
||||
t.column("body", .text)
|
||||
t.column("bodyContentType", .text)
|
||||
t.column("responseStatus", .integer)
|
||||
t.column("responseHeaders", .text)
|
||||
t.column("responseBody", .blob)
|
||||
t.column("lastSentAt", .double)
|
||||
t.column("createdAt", .double).notNull()
|
||||
}
|
||||
}
|
||||
|
||||
try migrator.migrate(dbPool)
|
||||
}
|
||||
}
|
||||
57
ProxyCore/Sources/DataLayer/Models/BlockListEntry.swift
Normal file
57
ProxyCore/Sources/DataLayer/Models/BlockListEntry.swift
Normal file
@@ -0,0 +1,57 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public enum BlockAction: String, Codable, Sendable, CaseIterable {
|
||||
case blockAndHide = "block_and_hide"
|
||||
case blockAndDisplay = "block_and_display"
|
||||
case hideOnly = "hide_only"
|
||||
|
||||
public var displayName: String {
|
||||
switch self {
|
||||
case .blockAndHide: "Block & Hide Request"
|
||||
case .blockAndDisplay: "Block & Display"
|
||||
case .hideOnly: "Hide but not Block"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public struct BlockListEntry: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var name: String?
|
||||
public var urlPattern: String
|
||||
public var method: String
|
||||
public var includeSubpaths: Bool
|
||||
public var blockAction: String
|
||||
public var isEnabled: Bool
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "block_list_entries"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
name: String? = nil,
|
||||
urlPattern: String,
|
||||
method: String = "ANY",
|
||||
includeSubpaths: Bool = true,
|
||||
blockAction: BlockAction = .blockAndHide,
|
||||
isEnabled: Bool = true,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.urlPattern = urlPattern
|
||||
self.method = method
|
||||
self.includeSubpaths = includeSubpaths
|
||||
self.blockAction = blockAction.rawValue
|
||||
self.isEnabled = isEnabled
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
|
||||
public var action: BlockAction {
|
||||
BlockAction(rawValue: blockAction) ?? .blockAndHide
|
||||
}
|
||||
}
|
||||
39
ProxyCore/Sources/DataLayer/Models/BreakpointRule.swift
Normal file
39
ProxyCore/Sources/DataLayer/Models/BreakpointRule.swift
Normal file
@@ -0,0 +1,39 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct BreakpointRule: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var name: String?
|
||||
public var urlPattern: String
|
||||
public var method: String
|
||||
public var interceptRequest: Bool
|
||||
public var interceptResponse: Bool
|
||||
public var isEnabled: Bool
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "breakpoint_rules"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
name: String? = nil,
|
||||
urlPattern: String,
|
||||
method: String = "ANY",
|
||||
interceptRequest: Bool = true,
|
||||
interceptResponse: Bool = true,
|
||||
isEnabled: Bool = true,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.urlPattern = urlPattern
|
||||
self.method = method
|
||||
self.interceptRequest = interceptRequest
|
||||
self.interceptResponse = interceptResponse
|
||||
self.isEnabled = isEnabled
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
140
ProxyCore/Sources/DataLayer/Models/CapturedTraffic.swift
Normal file
140
ProxyCore/Sources/DataLayer/Models/CapturedTraffic.swift
Normal file
@@ -0,0 +1,140 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct CapturedTraffic: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var requestId: String
|
||||
public var domain: String
|
||||
public var url: String
|
||||
public var method: String
|
||||
public var scheme: String
|
||||
public var statusCode: Int?
|
||||
public var statusText: String?
|
||||
|
||||
// Request
|
||||
public var requestHeaders: String?
|
||||
public var requestBody: Data?
|
||||
public var requestBodySize: Int
|
||||
public var requestContentType: String?
|
||||
public var queryParameters: String?
|
||||
|
||||
// Response
|
||||
public var responseHeaders: String?
|
||||
public var responseBody: Data?
|
||||
public var responseBodySize: Int
|
||||
public var responseContentType: String?
|
||||
|
||||
// Timing
|
||||
public var startedAt: Double
|
||||
public var completedAt: Double?
|
||||
public var durationMs: Int?
|
||||
|
||||
// Metadata
|
||||
public var isSslDecrypted: Bool
|
||||
public var isPinned: Bool
|
||||
public var isWebsocket: Bool
|
||||
public var isHidden: Bool
|
||||
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "captured_traffic"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
requestId: String = UUID().uuidString,
|
||||
domain: String,
|
||||
url: String,
|
||||
method: String,
|
||||
scheme: String,
|
||||
statusCode: Int? = nil,
|
||||
statusText: String? = nil,
|
||||
requestHeaders: String? = nil,
|
||||
requestBody: Data? = nil,
|
||||
requestBodySize: Int = 0,
|
||||
requestContentType: String? = nil,
|
||||
queryParameters: String? = nil,
|
||||
responseHeaders: String? = nil,
|
||||
responseBody: Data? = nil,
|
||||
responseBodySize: Int = 0,
|
||||
responseContentType: String? = nil,
|
||||
startedAt: Double = Date().timeIntervalSince1970,
|
||||
completedAt: Double? = nil,
|
||||
durationMs: Int? = nil,
|
||||
isSslDecrypted: Bool = false,
|
||||
isPinned: Bool = false,
|
||||
isWebsocket: Bool = false,
|
||||
isHidden: Bool = false,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.requestId = requestId
|
||||
self.domain = domain
|
||||
self.url = url
|
||||
self.method = method
|
||||
self.scheme = scheme
|
||||
self.statusCode = statusCode
|
||||
self.statusText = statusText
|
||||
self.requestHeaders = requestHeaders
|
||||
self.requestBody = requestBody
|
||||
self.requestBodySize = requestBodySize
|
||||
self.requestContentType = requestContentType
|
||||
self.queryParameters = queryParameters
|
||||
self.responseHeaders = responseHeaders
|
||||
self.responseBody = responseBody
|
||||
self.responseBodySize = responseBodySize
|
||||
self.responseContentType = responseContentType
|
||||
self.startedAt = startedAt
|
||||
self.completedAt = completedAt
|
||||
self.durationMs = durationMs
|
||||
self.isSslDecrypted = isSslDecrypted
|
||||
self.isPinned = isPinned
|
||||
self.isWebsocket = isWebsocket
|
||||
self.isHidden = isHidden
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Computed Properties
|
||||
|
||||
extension CapturedTraffic {
|
||||
public var decodedRequestHeaders: [String: String] {
|
||||
guard let data = requestHeaders?.data(using: .utf8),
|
||||
let dict = try? JSONDecoder().decode([String: String].self, from: data) else {
|
||||
return [:]
|
||||
}
|
||||
return dict
|
||||
}
|
||||
|
||||
public var decodedResponseHeaders: [String: String] {
|
||||
guard let data = responseHeaders?.data(using: .utf8),
|
||||
let dict = try? JSONDecoder().decode([String: String].self, from: data) else {
|
||||
return [:]
|
||||
}
|
||||
return dict
|
||||
}
|
||||
|
||||
public var decodedQueryParameters: [String: String] {
|
||||
guard let data = queryParameters?.data(using: .utf8),
|
||||
let dict = try? JSONDecoder().decode([String: String].self, from: data) else {
|
||||
return [:]
|
||||
}
|
||||
return dict
|
||||
}
|
||||
|
||||
public var startDate: Date {
|
||||
Date(timeIntervalSince1970: startedAt)
|
||||
}
|
||||
|
||||
public var formattedDuration: String {
|
||||
guard let ms = durationMs else { return "-" }
|
||||
if ms < 1000 {
|
||||
return "\(ms) ms"
|
||||
} else {
|
||||
return String(format: "%.1f s", Double(ms) / 1000.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
54
ProxyCore/Sources/DataLayer/Models/ComposeRequest.swift
Normal file
54
ProxyCore/Sources/DataLayer/Models/ComposeRequest.swift
Normal file
@@ -0,0 +1,54 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct ComposeRequest: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var name: String
|
||||
public var method: String
|
||||
public var url: String?
|
||||
public var headers: String?
|
||||
public var queryParameters: String?
|
||||
public var body: String?
|
||||
public var bodyContentType: String?
|
||||
public var responseStatus: Int?
|
||||
public var responseHeaders: String?
|
||||
public var responseBody: Data?
|
||||
public var lastSentAt: Double?
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "compose_requests"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
name: String = "New Request",
|
||||
method: String = "GET",
|
||||
url: String? = nil,
|
||||
headers: String? = nil,
|
||||
queryParameters: String? = nil,
|
||||
body: String? = nil,
|
||||
bodyContentType: String? = nil,
|
||||
responseStatus: Int? = nil,
|
||||
responseHeaders: String? = nil,
|
||||
responseBody: Data? = nil,
|
||||
lastSentAt: Double? = nil,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.method = method
|
||||
self.url = url
|
||||
self.headers = headers
|
||||
self.queryParameters = queryParameters
|
||||
self.body = body
|
||||
self.bodyContentType = bodyContentType
|
||||
self.responseStatus = responseStatus
|
||||
self.responseHeaders = responseHeaders
|
||||
self.responseBody = responseBody
|
||||
self.lastSentAt = lastSentAt
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
30
ProxyCore/Sources/DataLayer/Models/DNSSpoofRule.swift
Normal file
30
ProxyCore/Sources/DataLayer/Models/DNSSpoofRule.swift
Normal file
@@ -0,0 +1,30 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct DNSSpoofRule: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var sourceDomain: String
|
||||
public var targetDomain: String
|
||||
public var isEnabled: Bool
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "dns_spoof_rules"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
sourceDomain: String,
|
||||
targetDomain: String,
|
||||
isEnabled: Bool = true,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.sourceDomain = sourceDomain
|
||||
self.targetDomain = targetDomain
|
||||
self.isEnabled = isEnabled
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
13
ProxyCore/Sources/DataLayer/Models/DomainGroup.swift
Normal file
13
ProxyCore/Sources/DataLayer/Models/DomainGroup.swift
Normal file
@@ -0,0 +1,13 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct DomainGroup: Decodable, FetchableRecord, Identifiable, Hashable, Sendable {
|
||||
public var id: String { domain }
|
||||
public var domain: String
|
||||
public var requestCount: Int
|
||||
|
||||
public init(domain: String, requestCount: Int) {
|
||||
self.domain = domain
|
||||
self.requestCount = requestCount
|
||||
}
|
||||
}
|
||||
45
ProxyCore/Sources/DataLayer/Models/MapLocalRule.swift
Normal file
45
ProxyCore/Sources/DataLayer/Models/MapLocalRule.swift
Normal file
@@ -0,0 +1,45 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct MapLocalRule: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var name: String?
|
||||
public var urlPattern: String
|
||||
public var method: String
|
||||
public var responseStatus: Int
|
||||
public var responseHeaders: String?
|
||||
public var responseBody: Data?
|
||||
public var responseContentType: String?
|
||||
public var isEnabled: Bool
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "map_local_rules"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(
|
||||
id: Int64? = nil,
|
||||
name: String? = nil,
|
||||
urlPattern: String,
|
||||
method: String = "ANY",
|
||||
responseStatus: Int = 200,
|
||||
responseHeaders: String? = nil,
|
||||
responseBody: Data? = nil,
|
||||
responseContentType: String? = nil,
|
||||
isEnabled: Bool = true,
|
||||
createdAt: Double = Date().timeIntervalSince1970
|
||||
) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.urlPattern = urlPattern
|
||||
self.method = method
|
||||
self.responseStatus = responseStatus
|
||||
self.responseHeaders = responseHeaders
|
||||
self.responseBody = responseBody
|
||||
self.responseContentType = responseContentType
|
||||
self.isEnabled = isEnabled
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
22
ProxyCore/Sources/DataLayer/Models/SSLProxyingEntry.swift
Normal file
22
ProxyCore/Sources/DataLayer/Models/SSLProxyingEntry.swift
Normal file
@@ -0,0 +1,22 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public struct SSLProxyingEntry: Codable, FetchableRecord, MutablePersistableRecord, Identifiable, Sendable {
|
||||
public var id: Int64?
|
||||
public var domainPattern: String
|
||||
public var isInclude: Bool
|
||||
public var createdAt: Double
|
||||
|
||||
public static let databaseTableName = "ssl_proxying_entries"
|
||||
|
||||
public mutating func didInsert(_ inserted: InsertionSuccess) {
|
||||
id = inserted.rowID
|
||||
}
|
||||
|
||||
public init(id: Int64? = nil, domainPattern: String, isInclude: Bool, createdAt: Double = Date().timeIntervalSince1970) {
|
||||
self.id = id
|
||||
self.domainPattern = domainPattern
|
||||
self.isInclude = isInclude
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public final class ComposeRepository: Sendable {
|
||||
private let db: DatabaseManager
|
||||
|
||||
public init(db: DatabaseManager = .shared) {
|
||||
self.db = db
|
||||
}
|
||||
|
||||
public func observeRequests() -> ValueObservation<ValueReducers.Fetch<[ComposeRequest]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try ComposeRequest.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insert(_ request: inout ComposeRequest) throws {
|
||||
try db.dbPool.write { db in try request.insert(db) }
|
||||
}
|
||||
|
||||
public func update(_ request: ComposeRequest) throws {
|
||||
try db.dbPool.write { db in try request.update(db) }
|
||||
}
|
||||
|
||||
public func delete(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try ComposeRequest.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAll() throws {
|
||||
try db.dbPool.write { db in _ = try ComposeRequest.deleteAll(db) }
|
||||
}
|
||||
}
|
||||
128
ProxyCore/Sources/DataLayer/Repositories/RulesRepository.swift
Normal file
128
ProxyCore/Sources/DataLayer/Repositories/RulesRepository.swift
Normal file
@@ -0,0 +1,128 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public final class RulesRepository: Sendable {
|
||||
private let db: DatabaseManager
|
||||
|
||||
public init(db: DatabaseManager = .shared) {
|
||||
self.db = db
|
||||
}
|
||||
|
||||
// MARK: - SSL Proxying
|
||||
|
||||
public func observeSSLEntries() -> ValueObservation<ValueReducers.Fetch<[SSLProxyingEntry]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try SSLProxyingEntry.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func fetchAllSSLEntries() throws -> [SSLProxyingEntry] {
|
||||
try db.dbPool.read { db in
|
||||
try SSLProxyingEntry.fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insertSSLEntry(_ entry: inout SSLProxyingEntry) throws {
|
||||
try db.dbPool.write { db in try entry.insert(db) }
|
||||
}
|
||||
|
||||
public func deleteSSLEntry(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try SSLProxyingEntry.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAllSSLEntries() throws {
|
||||
try db.dbPool.write { db in _ = try SSLProxyingEntry.deleteAll(db) }
|
||||
}
|
||||
|
||||
// MARK: - Block List
|
||||
|
||||
public func observeBlockListEntries() -> ValueObservation<ValueReducers.Fetch<[BlockListEntry]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try BlockListEntry.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insertBlockEntry(_ entry: inout BlockListEntry) throws {
|
||||
try db.dbPool.write { db in try entry.insert(db) }
|
||||
}
|
||||
|
||||
public func updateBlockEntry(_ entry: BlockListEntry) throws {
|
||||
try db.dbPool.write { db in try entry.update(db) }
|
||||
}
|
||||
|
||||
public func deleteBlockEntry(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try BlockListEntry.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAllBlockEntries() throws {
|
||||
try db.dbPool.write { db in _ = try BlockListEntry.deleteAll(db) }
|
||||
}
|
||||
|
||||
// MARK: - Breakpoint Rules
|
||||
|
||||
public func observeBreakpointRules() -> ValueObservation<ValueReducers.Fetch<[BreakpointRule]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try BreakpointRule.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insertBreakpointRule(_ rule: inout BreakpointRule) throws {
|
||||
try db.dbPool.write { db in try rule.insert(db) }
|
||||
}
|
||||
|
||||
public func updateBreakpointRule(_ rule: BreakpointRule) throws {
|
||||
try db.dbPool.write { db in try rule.update(db) }
|
||||
}
|
||||
|
||||
public func deleteBreakpointRule(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try BreakpointRule.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAllBreakpointRules() throws {
|
||||
try db.dbPool.write { db in _ = try BreakpointRule.deleteAll(db) }
|
||||
}
|
||||
|
||||
// MARK: - Map Local Rules
|
||||
|
||||
public func observeMapLocalRules() -> ValueObservation<ValueReducers.Fetch<[MapLocalRule]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try MapLocalRule.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insertMapLocalRule(_ rule: inout MapLocalRule) throws {
|
||||
try db.dbPool.write { db in try rule.insert(db) }
|
||||
}
|
||||
|
||||
public func updateMapLocalRule(_ rule: MapLocalRule) throws {
|
||||
try db.dbPool.write { db in try rule.update(db) }
|
||||
}
|
||||
|
||||
public func deleteMapLocalRule(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try MapLocalRule.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAllMapLocalRules() throws {
|
||||
try db.dbPool.write { db in _ = try MapLocalRule.deleteAll(db) }
|
||||
}
|
||||
|
||||
// MARK: - DNS Spoof Rules
|
||||
|
||||
public func observeDNSSpoofRules() -> ValueObservation<ValueReducers.Fetch<[DNSSpoofRule]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try DNSSpoofRule.order(Column("createdAt").desc).fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func insertDNSSpoofRule(_ rule: inout DNSSpoofRule) throws {
|
||||
try db.dbPool.write { db in try rule.insert(db) }
|
||||
}
|
||||
|
||||
public func deleteDNSSpoofRule(id: Int64) throws {
|
||||
try db.dbPool.write { db in _ = try DNSSpoofRule.deleteOne(db, id: id) }
|
||||
}
|
||||
|
||||
public func deleteAllDNSSpoofRules() throws {
|
||||
try db.dbPool.write { db in _ = try DNSSpoofRule.deleteAll(db) }
|
||||
}
|
||||
}
|
||||
110
ProxyCore/Sources/DataLayer/Repositories/TrafficRepository.swift
Normal file
110
ProxyCore/Sources/DataLayer/Repositories/TrafficRepository.swift
Normal file
@@ -0,0 +1,110 @@
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
public final class TrafficRepository: Sendable {
|
||||
private let db: DatabaseManager
|
||||
|
||||
public init(db: DatabaseManager = .shared) {
|
||||
self.db = db
|
||||
}
|
||||
|
||||
// MARK: - Domain Groups
|
||||
|
||||
public func observeDomainGroups() -> ValueObservation<ValueReducers.Fetch<[DomainGroup]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try DomainGroup.fetchAll(db, sql: """
|
||||
SELECT domain, COUNT(*) as requestCount
|
||||
FROM captured_traffic
|
||||
WHERE isHidden = 0
|
||||
GROUP BY domain
|
||||
ORDER BY MAX(startedAt) DESC
|
||||
""")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Traffic for Domain
|
||||
|
||||
public func observeTraffic(forDomain domain: String) -> ValueObservation<ValueReducers.Fetch<[CapturedTraffic]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try CapturedTraffic
|
||||
.filter(Column("domain") == domain)
|
||||
.filter(Column("isHidden") == false)
|
||||
.order(Column("startedAt").desc)
|
||||
.fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Pinned
|
||||
|
||||
public func observePinnedTraffic() -> ValueObservation<ValueReducers.Fetch<[CapturedTraffic]>> {
|
||||
ValueObservation.tracking { db in
|
||||
try CapturedTraffic
|
||||
.filter(Column("isPinned") == true)
|
||||
.order(Column("startedAt").desc)
|
||||
.fetchAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Single Request
|
||||
|
||||
public func traffic(byId id: Int64) throws -> CapturedTraffic? {
|
||||
try db.dbPool.read { db in
|
||||
try CapturedTraffic.fetchOne(db, id: id)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Write Operations
|
||||
|
||||
public func insert(_ traffic: inout CapturedTraffic) throws {
|
||||
try db.dbPool.write { db in
|
||||
try traffic.insert(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func updateResponse(
|
||||
requestId: String,
|
||||
statusCode: Int,
|
||||
statusText: String,
|
||||
responseHeaders: String?,
|
||||
responseBody: Data?,
|
||||
responseBodySize: Int,
|
||||
responseContentType: String?,
|
||||
completedAt: Double,
|
||||
durationMs: Int
|
||||
) throws {
|
||||
try db.dbPool.write { db in
|
||||
try db.execute(sql: """
|
||||
UPDATE captured_traffic SET
|
||||
statusCode = ?, statusText = ?,
|
||||
responseHeaders = ?, responseBody = ?,
|
||||
responseBodySize = ?, responseContentType = ?,
|
||||
completedAt = ?, durationMs = ?
|
||||
WHERE requestId = ?
|
||||
""", arguments: [
|
||||
statusCode, statusText,
|
||||
responseHeaders, responseBody,
|
||||
responseBodySize, responseContentType,
|
||||
completedAt, durationMs,
|
||||
requestId
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
public func togglePin(id: Int64, isPinned: Bool) throws {
|
||||
try db.dbPool.write { db in
|
||||
try db.execute(sql: "UPDATE captured_traffic SET isPinned = ? WHERE id = ?", arguments: [isPinned, id])
|
||||
}
|
||||
}
|
||||
|
||||
public func deleteAll() throws {
|
||||
try db.dbPool.write { db in
|
||||
_ = try CapturedTraffic.deleteAll(db)
|
||||
}
|
||||
}
|
||||
|
||||
public func deleteForDomain(_ domain: String) throws {
|
||||
try db.dbPool.write { db in
|
||||
_ = try CapturedTraffic.filter(Column("domain") == domain).deleteAll(db)
|
||||
}
|
||||
}
|
||||
}
|
||||
290
ProxyCore/Sources/ProxyEngine/CertificateManager.swift
Normal file
290
ProxyCore/Sources/ProxyEngine/CertificateManager.swift
Normal file
@@ -0,0 +1,290 @@
|
||||
import Foundation
|
||||
import X509
|
||||
import Crypto
|
||||
import SwiftASN1
|
||||
import NIOSSL
|
||||
#if canImport(UIKit)
|
||||
import UIKit
|
||||
#endif
|
||||
|
||||
/// Manages root CA generation, leaf certificate signing, and an LRU certificate cache.
|
||||
public final class CertificateManager: @unchecked Sendable {
|
||||
public static let shared = CertificateManager()
|
||||
|
||||
private let lock = NSLock()
|
||||
private var rootCAKey: P256.Signing.PrivateKey?
|
||||
private var rootCACert: Certificate?
|
||||
private var rootCANIOSSL: NIOSSLCertificate?
|
||||
|
||||
// LRU cache for generated leaf certificates
|
||||
private var certCache: [String: (NIOSSLCertificate, NIOSSLPrivateKey)] = [:]
|
||||
private var cacheOrder: [String] = []
|
||||
|
||||
private let keychainCAKeyTag = "com.treyt.proxyapp.ca.privatekey"
|
||||
private let keychainCACertTag = "com.treyt.proxyapp.ca.cert"
|
||||
|
||||
private init() {
|
||||
loadOrGenerateCA()
|
||||
}
|
||||
|
||||
// MARK: - Public API
|
||||
|
||||
public var hasCA: Bool {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
return rootCACert != nil
|
||||
}
|
||||
|
||||
/// Get or generate a leaf certificate + key for the given domain.
|
||||
public func tlsServerContext(for domain: String) throws -> NIOSSLContext {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
|
||||
if let cached = certCache[domain] {
|
||||
cacheOrder.removeAll { $0 == domain }
|
||||
cacheOrder.append(domain)
|
||||
return try makeServerContext(cert: cached.0, key: cached.1)
|
||||
}
|
||||
|
||||
guard let caKey = rootCAKey, let caCert = rootCACert else {
|
||||
throw CertificateError.caNotFound
|
||||
}
|
||||
|
||||
let (leafCert, leafKey) = try generateLeaf(domain: domain, caKey: caKey, caCert: caCert)
|
||||
|
||||
// Serialize to DER/PEM for NIOSSL
|
||||
var serializer = DER.Serializer()
|
||||
try leafCert.serialize(into: &serializer)
|
||||
let leafDER = serializer.serializedBytes
|
||||
let nioLeafCert = try NIOSSLCertificate(bytes: leafDER, format: .der)
|
||||
let leafKeyPEM = leafKey.pemRepresentation
|
||||
let nioLeafKey = try NIOSSLPrivateKey(bytes: [UInt8](leafKeyPEM.utf8), format: .pem)
|
||||
|
||||
certCache[domain] = (nioLeafCert, nioLeafKey)
|
||||
cacheOrder.append(domain)
|
||||
|
||||
while cacheOrder.count > ProxyConstants.certificateCacheSize {
|
||||
let evicted = cacheOrder.removeFirst()
|
||||
certCache.removeValue(forKey: evicted)
|
||||
}
|
||||
|
||||
return try makeServerContext(cert: nioLeafCert, key: nioLeafKey)
|
||||
}
|
||||
|
||||
/// Export the root CA as DER data for user installation.
|
||||
public func exportCACertificateDER() -> [UInt8]? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
guard let cert = rootCACert else { return nil }
|
||||
var serializer = DER.Serializer()
|
||||
try? cert.serialize(into: &serializer)
|
||||
return serializer.serializedBytes
|
||||
}
|
||||
|
||||
/// Export as PEM for display.
|
||||
public func exportCACertificatePEM() -> String? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
guard let cert = rootCACert else { return nil }
|
||||
guard let pem = try? cert.serializeAsPEM() else { return nil }
|
||||
return pem.pemString
|
||||
}
|
||||
|
||||
public var caNotValidAfter: Date? {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
// notValidAfter is a Time, not directly a Date — we stored the date when generating
|
||||
return nil // Will be set properly after we store dates
|
||||
}
|
||||
|
||||
// MARK: - CA Generation
|
||||
|
||||
private func loadOrGenerateCA() {
|
||||
if loadCAFromKeychain() { return }
|
||||
|
||||
do {
|
||||
let key = P256.Signing.PrivateKey()
|
||||
let name = try DistinguishedName {
|
||||
CommonName("Proxy CA (\(deviceName()))")
|
||||
OrganizationName("ProxyApp")
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let twoYearsLater = now.addingTimeInterval(365 * 24 * 3600 * 2)
|
||||
|
||||
let extensions = try Certificate.Extensions {
|
||||
Critical(BasicConstraints.isCertificateAuthority(maxPathLength: 0))
|
||||
Critical(KeyUsage(keyCertSign: true, cRLSign: true))
|
||||
}
|
||||
|
||||
let cert = try Certificate(
|
||||
version: .v3,
|
||||
serialNumber: Certificate.SerialNumber(),
|
||||
publicKey: .init(key.publicKey),
|
||||
notValidBefore: now,
|
||||
notValidAfter: twoYearsLater,
|
||||
issuer: name,
|
||||
subject: name,
|
||||
signatureAlgorithm: .ecdsaWithSHA256,
|
||||
extensions: extensions,
|
||||
issuerPrivateKey: .init(key)
|
||||
)
|
||||
|
||||
self.rootCAKey = key
|
||||
self.rootCACert = cert
|
||||
|
||||
var serializer = DER.Serializer()
|
||||
try cert.serialize(into: &serializer)
|
||||
let der = serializer.serializedBytes
|
||||
self.rootCANIOSSL = try NIOSSLCertificate(bytes: der, format: .der)
|
||||
|
||||
saveCAToKeychain(key: key, certDER: der)
|
||||
print("[CertificateManager] Generated new root CA")
|
||||
} catch {
|
||||
print("[CertificateManager] Failed to generate CA: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Leaf Certificate Generation
|
||||
|
||||
private func generateLeaf(
|
||||
domain: String,
|
||||
caKey: P256.Signing.PrivateKey,
|
||||
caCert: Certificate
|
||||
) throws -> (Certificate, P256.Signing.PrivateKey) {
|
||||
let leafKey = P256.Signing.PrivateKey()
|
||||
let now = Date()
|
||||
let oneYearLater = now.addingTimeInterval(365 * 24 * 3600)
|
||||
|
||||
let extensions = try Certificate.Extensions {
|
||||
Critical(BasicConstraints.notCertificateAuthority)
|
||||
Critical(KeyUsage(digitalSignature: true))
|
||||
try ExtendedKeyUsage([.serverAuth])
|
||||
SubjectAlternativeNames([.dnsName(domain)])
|
||||
}
|
||||
|
||||
let leafName = try DistinguishedName {
|
||||
CommonName(domain)
|
||||
OrganizationName("ProxyApp")
|
||||
}
|
||||
|
||||
let cert = try Certificate(
|
||||
version: .v3,
|
||||
serialNumber: Certificate.SerialNumber(),
|
||||
publicKey: .init(leafKey.publicKey),
|
||||
notValidBefore: now,
|
||||
notValidAfter: oneYearLater,
|
||||
issuer: caCert.subject,
|
||||
subject: leafName,
|
||||
signatureAlgorithm: .ecdsaWithSHA256,
|
||||
extensions: extensions,
|
||||
issuerPrivateKey: .init(caKey)
|
||||
)
|
||||
|
||||
return (cert, leafKey)
|
||||
}
|
||||
|
||||
// MARK: - TLS Context
|
||||
|
||||
private func makeServerContext(cert: NIOSSLCertificate, key: NIOSSLPrivateKey) throws -> NIOSSLContext {
|
||||
var certs = [cert]
|
||||
if let caCert = rootCANIOSSL {
|
||||
certs.append(caCert)
|
||||
}
|
||||
var config = TLSConfiguration.makeServerConfiguration(
|
||||
certificateChain: certs.map { .certificate($0) },
|
||||
privateKey: .privateKey(key)
|
||||
)
|
||||
config.applicationProtocols = ["http/1.1"]
|
||||
return try NIOSSLContext(configuration: config)
|
||||
}
|
||||
|
||||
// MARK: - Keychain
|
||||
|
||||
private func loadCAFromKeychain() -> Bool {
|
||||
let keyQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCAKeyTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecReturnData as String: true
|
||||
]
|
||||
var keyResult: AnyObject?
|
||||
guard SecItemCopyMatching(keyQuery as CFDictionary, &keyResult) == errSecSuccess,
|
||||
let keyData = keyResult as? Data else { return false }
|
||||
|
||||
let certQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCACertTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecReturnData as String: true
|
||||
]
|
||||
var certResult: AnyObject?
|
||||
guard SecItemCopyMatching(certQuery as CFDictionary, &certResult) == errSecSuccess,
|
||||
let certData = certResult as? Data else { return false }
|
||||
|
||||
do {
|
||||
let key = try P256.Signing.PrivateKey(rawRepresentation: keyData)
|
||||
let cert = try Certificate(derEncoded: [UInt8](certData))
|
||||
let nioCert = try NIOSSLCertificate(bytes: [UInt8](certData), format: .der)
|
||||
|
||||
self.rootCAKey = key
|
||||
self.rootCACert = cert
|
||||
self.rootCANIOSSL = nioCert
|
||||
print("[CertificateManager] Loaded CA from Keychain")
|
||||
return true
|
||||
} catch {
|
||||
print("[CertificateManager] Failed to load CA from Keychain: \(error)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func saveCAToKeychain(key: P256.Signing.PrivateKey, certDER: [UInt8]) {
|
||||
let keyData = key.rawRepresentation
|
||||
|
||||
// Delete existing entries
|
||||
for tag in [keychainCAKeyTag, keychainCACertTag] {
|
||||
let deleteQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: tag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier
|
||||
]
|
||||
SecItemDelete(deleteQuery as CFDictionary)
|
||||
}
|
||||
|
||||
// Save key
|
||||
let addKeyQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCAKeyTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecValueData as String: keyData,
|
||||
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
|
||||
]
|
||||
SecItemAdd(addKeyQuery as CFDictionary, nil)
|
||||
|
||||
// Save cert
|
||||
let addCertQuery: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrService as String: keychainCACertTag,
|
||||
kSecAttrAccessGroup as String: ProxyConstants.appGroupIdentifier,
|
||||
kSecValueData as String: Data(certDER),
|
||||
kSecAttrAccessible as String: kSecAttrAccessibleAfterFirstUnlock
|
||||
]
|
||||
SecItemAdd(addCertQuery as CFDictionary, nil)
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func deviceName() -> String {
|
||||
#if canImport(UIKit)
|
||||
return UIDevice.current.name
|
||||
#else
|
||||
return Host.current().localizedName ?? "Unknown"
|
||||
#endif
|
||||
}
|
||||
|
||||
public enum CertificateError: Error {
|
||||
case notImplemented
|
||||
case generationFailed
|
||||
case caNotFound
|
||||
}
|
||||
}
|
||||
298
ProxyCore/Sources/ProxyEngine/ConnectHandler.swift
Normal file
298
ProxyCore/Sources/ProxyEngine/ConnectHandler.swift
Normal file
@@ -0,0 +1,298 @@
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
import NIOHTTP1
|
||||
|
||||
/// Handles incoming proxy requests:
|
||||
/// - HTTP CONNECT → establishes TCP tunnel (GlueHandler passthrough, or MITM in Phase 3)
|
||||
/// - Plain HTTP → connects upstream, forwards request, captures request+response
|
||||
final class ConnectHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPServerRequestPart
|
||||
typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private let trafficRepo: TrafficRepository
|
||||
|
||||
// Buffer request parts until we've connected upstream
|
||||
private var pendingHead: HTTPRequestHead?
|
||||
private var pendingBody: [ByteBuffer] = []
|
||||
private var pendingEnd: HTTPHeaders?
|
||||
private var receivedEnd = false
|
||||
|
||||
init(trafficRepo: TrafficRepository) {
|
||||
self.trafficRepo = trafficRepo
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
if head.method == .CONNECT {
|
||||
handleConnect(context: context, head: head)
|
||||
} else {
|
||||
pendingHead = head
|
||||
}
|
||||
case .body(let buffer):
|
||||
pendingBody.append(buffer)
|
||||
case .end(let trailers):
|
||||
if pendingHead != nil {
|
||||
pendingEnd = trailers
|
||||
receivedEnd = true
|
||||
handleHTTPRequest(context: context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - CONNECT (HTTPS tunnel)
|
||||
|
||||
private func handleConnect(context: ChannelHandlerContext, head: HTTPRequestHead) {
|
||||
let components = head.uri.split(separator: ":")
|
||||
let host = String(components[0])
|
||||
let port = components.count > 1 ? Int(components[1]) ?? 443 : 443
|
||||
|
||||
// Check if this domain should be MITM'd (SSL Proxying enabled + domain in include list)
|
||||
let shouldMITM = shouldInterceptSSL(domain: host)
|
||||
|
||||
// Send 200 Connection Established
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
|
||||
if shouldMITM {
|
||||
// MITM mode: strip HTTP handlers, install MITMHandler
|
||||
setupMITM(context: context, host: host, port: port)
|
||||
} else {
|
||||
// Passthrough mode: record domain-level entry, tunnel raw bytes
|
||||
recordConnectTraffic(host: host, port: port)
|
||||
|
||||
// We don't need to connect upstream ourselves — GlueHandler does raw forwarding
|
||||
// But GlueHandler pairs two channels, so we need the remote channel first
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.connect(host: host, port: port)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let remoteChannel):
|
||||
self.setupGlue(context: context, remoteChannel: remoteChannel)
|
||||
case .failure(let error):
|
||||
print("[Proxy] CONNECT passthrough failed to \(host):\(port): \(error)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func shouldInterceptSSL(domain: String) -> Bool {
|
||||
guard IPCManager.shared.isSSLProxyingEnabled else { return false }
|
||||
guard CertificateManager.shared.hasCA else { return false }
|
||||
|
||||
// Check SSL proxying list from database
|
||||
let rulesRepo = RulesRepository()
|
||||
do {
|
||||
let entries = try rulesRepo.fetchAllSSLEntries()
|
||||
|
||||
// Check exclude list first
|
||||
for entry in entries where !entry.isInclude {
|
||||
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check include list
|
||||
for entry in entries where entry.isInclude {
|
||||
if WildcardMatcher.matches(domain, pattern: entry.domainPattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[Proxy] Failed to check SSL proxying list: \(error)")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
private func setupMITM(context: ChannelHandlerContext, host: String, port: Int) {
|
||||
let mitmHandler = MITMHandler(host: host, port: port, trafficRepo: trafficRepo)
|
||||
|
||||
// Remove HTTP handlers, keep raw bytes for MITMHandler
|
||||
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
|
||||
.whenSuccess { handler in
|
||||
context.channel.pipeline.removeHandler(handler, promise: nil)
|
||||
}
|
||||
|
||||
context.pipeline.removeHandler(context: context).whenComplete { _ in
|
||||
context.channel.pipeline.addHandler(mitmHandler).whenFailure { error in
|
||||
print("[Proxy] Failed to install MITM handler: \(error)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func setupGlue(context: ChannelHandlerContext, remoteChannel: Channel) {
|
||||
let localGlue = GlueHandler()
|
||||
let remoteGlue = GlueHandler()
|
||||
localGlue.partner = remoteGlue
|
||||
remoteGlue.partner = localGlue
|
||||
|
||||
// Remove all HTTP handlers from the client channel, leaving raw bytes
|
||||
context.channel.pipeline.handler(type: ByteToMessageHandler<HTTPRequestDecoder>.self)
|
||||
.whenSuccess { handler in
|
||||
context.channel.pipeline.removeHandler(handler, promise: nil)
|
||||
}
|
||||
|
||||
context.pipeline.removeHandler(context: context).whenComplete { _ in
|
||||
context.channel.pipeline.addHandler(localGlue).whenSuccess {
|
||||
remoteChannel.pipeline.addHandler(remoteGlue).whenFailure { _ in
|
||||
context.close(promise: nil)
|
||||
remoteChannel.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Plain HTTP forwarding
|
||||
|
||||
private func handleHTTPRequest(context: ChannelHandlerContext) {
|
||||
guard let head = pendingHead else { return }
|
||||
|
||||
// Parse host and port from the absolute URI or Host header
|
||||
guard let (host, port, path) = parseHTTPTarget(head: head) else {
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .badRequest)
|
||||
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite the request URI to relative path (upstream expects /path, not http://host/path)
|
||||
var upstreamHead = head
|
||||
upstreamHead.uri = path
|
||||
// Ensure Host header is set
|
||||
if !upstreamHead.headers.contains(name: "Host") {
|
||||
upstreamHead.headers.add(name: "Host", value: host)
|
||||
}
|
||||
|
||||
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: host, scheme: "http")
|
||||
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
// Remote channel: decode HTTP responses, encode HTTP requests
|
||||
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
|
||||
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(captureHandler)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(
|
||||
HTTPRelayHandler(clientContext: context, wrapResponse: self.wrapOutboundOut)
|
||||
)
|
||||
}
|
||||
}
|
||||
.connect(host: host, port: port)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let remoteChannel):
|
||||
// Forward the buffered request to upstream
|
||||
remoteChannel.write(NIOAny(HTTPClientRequestPart.head(upstreamHead)), promise: nil)
|
||||
for bodyBuffer in self.pendingBody {
|
||||
remoteChannel.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyBuffer))), promise: nil)
|
||||
}
|
||||
remoteChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(self.pendingEnd)), promise: nil)
|
||||
|
||||
// Clear buffered data
|
||||
self.pendingHead = nil
|
||||
self.pendingBody.removeAll()
|
||||
self.pendingEnd = nil
|
||||
|
||||
case .failure(let error):
|
||||
print("[Proxy] HTTP forward failed to \(host):\(port): \(error)")
|
||||
let responseHead = HTTPResponseHead(version: .http1_1, status: .badGateway)
|
||||
context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
|
||||
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - URL Parsing
|
||||
|
||||
private func parseHTTPTarget(head: HTTPRequestHead) -> (host: String, port: Int, path: String)? {
|
||||
// Absolute URI: "http://example.com:8080/path?query"
|
||||
if head.uri.hasPrefix("http://") || head.uri.hasPrefix("https://") {
|
||||
guard let url = URLComponents(string: head.uri) else { return nil }
|
||||
let host = url.host ?? ""
|
||||
let port = url.port ?? (head.uri.hasPrefix("https") ? 443 : 80)
|
||||
var path = url.path.isEmpty ? "/" : url.path
|
||||
if let query = url.query {
|
||||
path += "?\(query)"
|
||||
}
|
||||
return (host, port, path)
|
||||
}
|
||||
|
||||
// Relative URI with Host header
|
||||
if let hostHeader = head.headers.first(name: "Host") {
|
||||
let parts = hostHeader.split(separator: ":")
|
||||
let host = String(parts[0])
|
||||
let port = parts.count > 1 ? Int(parts[1]) ?? 80 : 80
|
||||
return (host, port, head.uri)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - CONNECT traffic recording
|
||||
|
||||
private func recordConnectTraffic(host: String, port: Int) {
|
||||
var traffic = CapturedTraffic(
|
||||
domain: host,
|
||||
url: "https://\(host):\(port)",
|
||||
method: "CONNECT",
|
||||
scheme: "https",
|
||||
statusCode: 200,
|
||||
statusText: "Connection Established",
|
||||
startedAt: Date().timeIntervalSince1970,
|
||||
completedAt: Date().timeIntervalSince1970,
|
||||
durationMs: 0,
|
||||
isSslDecrypted: false
|
||||
)
|
||||
try? trafficRepo.insert(&traffic)
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - HTTPRelayHandler
|
||||
|
||||
/// Relays HTTP responses from the upstream server back to the proxy client.
|
||||
final class HTTPRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
|
||||
private let clientContext: ChannelHandlerContext
|
||||
private let wrapResponse: (HTTPServerResponsePart) -> NIOAny
|
||||
|
||||
init(clientContext: ChannelHandlerContext, wrapResponse: @escaping (HTTPServerResponsePart) -> NIOAny) {
|
||||
self.clientContext = clientContext
|
||||
self.wrapResponse = wrapResponse
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
let serverHead = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
|
||||
clientContext.write(wrapResponse(.head(serverHead)), promise: nil)
|
||||
case .body(let buffer):
|
||||
clientContext.write(wrapResponse(.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
clientContext.writeAndFlush(wrapResponse(.end(trailers)), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[Proxy] Relay error: \(error)")
|
||||
context.close(promise: nil)
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
}
|
||||
66
ProxyCore/Sources/ProxyEngine/GlueHandler.swift
Normal file
66
ProxyCore/Sources/ProxyEngine/GlueHandler.swift
Normal file
@@ -0,0 +1,66 @@
|
||||
import Foundation
|
||||
import NIOCore
|
||||
|
||||
/// Bidirectional TCP forwarder. Pairs two channels so bytes flow in both directions.
|
||||
/// Used for CONNECT tunneling (passthrough mode, no MITM).
|
||||
final class GlueHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias OutboundOut = ByteBuffer
|
||||
|
||||
var partner: GlueHandler?
|
||||
private var context: ChannelHandlerContext?
|
||||
private var pendingRead = false
|
||||
|
||||
func handlerAdded(context: ChannelHandlerContext) {
|
||||
self.context = context
|
||||
}
|
||||
|
||||
func handlerRemoved(context: ChannelHandlerContext) {
|
||||
self.context = nil
|
||||
self.partner = nil
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
partner?.write(unwrapInboundIn(data))
|
||||
}
|
||||
|
||||
func channelReadComplete(context: ChannelHandlerContext) {
|
||||
partner?.flush()
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
partner?.close()
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
context.close(promise: nil)
|
||||
}
|
||||
|
||||
func channelWritabilityChanged(context: ChannelHandlerContext) {
|
||||
if context.channel.isWritable {
|
||||
partner?.read()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Partner operations
|
||||
|
||||
private func write(_ buffer: ByteBuffer) {
|
||||
context?.write(wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
|
||||
private func flush() {
|
||||
context?.flush()
|
||||
}
|
||||
|
||||
private func read() {
|
||||
if let context, !pendingRead {
|
||||
pendingRead = true
|
||||
context.read()
|
||||
pendingRead = false
|
||||
}
|
||||
}
|
||||
|
||||
private func close() {
|
||||
context?.close(promise: nil)
|
||||
}
|
||||
}
|
||||
141
ProxyCore/Sources/ProxyEngine/HTTPCaptureHandler.swift
Normal file
141
ProxyCore/Sources/ProxyEngine/HTTPCaptureHandler.swift
Normal file
@@ -0,0 +1,141 @@
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOHTTP1
|
||||
|
||||
/// Captures HTTP request/response pairs and writes them to the traffic database.
|
||||
/// Inserted into the pipeline after TLS termination (MITM) or for plain HTTP.
|
||||
final class HTTPCaptureHandler: ChannelDuplexHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
typealias InboundOut = HTTPClientResponsePart
|
||||
typealias OutboundIn = HTTPClientRequestPart
|
||||
typealias OutboundOut = HTTPClientRequestPart
|
||||
|
||||
private let trafficRepo: TrafficRepository
|
||||
private let domain: String
|
||||
private let scheme: String
|
||||
|
||||
private var currentRequestId: String?
|
||||
private var requestHead: HTTPRequestHead?
|
||||
private var requestBody = Data()
|
||||
private var responseHead: HTTPResponseHead?
|
||||
private var responseBody = Data()
|
||||
private var requestStartTime: Double = 0
|
||||
|
||||
init(trafficRepo: TrafficRepository, domain: String, scheme: String = "https") {
|
||||
self.trafficRepo = trafficRepo
|
||||
self.domain = domain
|
||||
self.scheme = scheme
|
||||
}
|
||||
|
||||
// MARK: - Outbound (Request)
|
||||
|
||||
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
let part = unwrapOutboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
currentRequestId = UUID().uuidString
|
||||
requestHead = head
|
||||
requestBody = Data()
|
||||
requestStartTime = Date().timeIntervalSince1970
|
||||
case .body(.byteBuffer(let buffer)):
|
||||
if requestBody.count < ProxyConstants.maxBodySizeBytes {
|
||||
requestBody.append(contentsOf: buffer.readableBytesView)
|
||||
}
|
||||
case .end:
|
||||
saveRequest()
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
|
||||
// MARK: - Inbound (Response)
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
responseHead = head
|
||||
responseBody = Data()
|
||||
case .body(let buffer):
|
||||
if responseBody.count < ProxyConstants.maxBodySizeBytes {
|
||||
responseBody.append(contentsOf: buffer.readableBytesView)
|
||||
}
|
||||
case .end:
|
||||
saveResponse()
|
||||
}
|
||||
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
||||
// MARK: - Persistence
|
||||
|
||||
private func saveRequest() {
|
||||
guard let head = requestHead, let reqId = currentRequestId else { return }
|
||||
|
||||
let url = "\(scheme)://\(domain)\(head.uri)"
|
||||
let headersJSON = encodeHeaders(head.headers)
|
||||
let queryParams = extractQueryParams(from: head.uri)
|
||||
|
||||
var traffic = CapturedTraffic(
|
||||
requestId: reqId,
|
||||
domain: domain,
|
||||
url: url,
|
||||
method: head.method.rawValue,
|
||||
scheme: scheme,
|
||||
requestHeaders: headersJSON,
|
||||
requestBody: requestBody.isEmpty ? nil : requestBody,
|
||||
requestBodySize: requestBody.count,
|
||||
requestContentType: head.headers.first(name: "Content-Type"),
|
||||
queryParameters: queryParams,
|
||||
startedAt: requestStartTime,
|
||||
isSslDecrypted: scheme == "https"
|
||||
)
|
||||
|
||||
try? trafficRepo.insert(&traffic)
|
||||
}
|
||||
|
||||
private func saveResponse() {
|
||||
guard let reqId = currentRequestId, let head = responseHead else { return }
|
||||
|
||||
let now = Date().timeIntervalSince1970
|
||||
let durationMs = Int((now - requestStartTime) * 1000)
|
||||
|
||||
try? trafficRepo.updateResponse(
|
||||
requestId: reqId,
|
||||
statusCode: Int(head.status.code),
|
||||
statusText: head.status.reasonPhrase,
|
||||
responseHeaders: encodeHeaders(head.headers),
|
||||
responseBody: responseBody.isEmpty ? nil : responseBody,
|
||||
responseBodySize: responseBody.count,
|
||||
responseContentType: head.headers.first(name: "Content-Type"),
|
||||
completedAt: now,
|
||||
durationMs: durationMs
|
||||
)
|
||||
|
||||
IPCManager.shared.post(.newTrafficCaptured)
|
||||
}
|
||||
|
||||
private func encodeHeaders(_ headers: HTTPHeaders) -> String? {
|
||||
var dict: [String: String] = [:]
|
||||
for (name, value) in headers {
|
||||
dict[name] = value
|
||||
}
|
||||
guard let data = try? JSONEncoder().encode(dict) else { return nil }
|
||||
return String(data: data, encoding: .utf8)
|
||||
}
|
||||
|
||||
private func extractQueryParams(from uri: String) -> String? {
|
||||
guard let url = URLComponents(string: uri),
|
||||
let items = url.queryItems, !items.isEmpty else { return nil }
|
||||
var dict: [String: String] = [:]
|
||||
for item in items {
|
||||
dict[item.name] = item.value ?? ""
|
||||
}
|
||||
guard let data = try? JSONEncoder().encode(dict) else { return nil }
|
||||
return String(data: data, encoding: .utf8)
|
||||
}
|
||||
}
|
||||
294
ProxyCore/Sources/ProxyEngine/MITMHandler.swift
Normal file
294
ProxyCore/Sources/ProxyEngine/MITMHandler.swift
Normal file
@@ -0,0 +1,294 @@
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
import NIOSSL
|
||||
import NIOHTTP1
|
||||
|
||||
/// After a CONNECT tunnel is established, this handler:
|
||||
/// 1. Reads the first bytes from the client to extract the SNI hostname from the TLS ClientHello
|
||||
/// 2. Generates a per-domain leaf certificate via CertificateManager
|
||||
/// 3. Terminates client-side TLS with the generated cert
|
||||
/// 4. Initiates server-side TLS to the real server
|
||||
/// 5. Installs HTTP codecs + HTTPCaptureHandler on both sides to capture decrypted traffic
|
||||
final class MITMHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private let host: String
|
||||
private let port: Int
|
||||
private let trafficRepo: TrafficRepository
|
||||
private let certManager: CertificateManager
|
||||
|
||||
init(host: String, port: Int, trafficRepo: TrafficRepository, certManager: CertificateManager = .shared) {
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.trafficRepo = trafficRepo
|
||||
self.certManager = certManager
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
var buffer = unwrapInboundIn(data)
|
||||
|
||||
// Extract SNI from ClientHello if possible, otherwise use the CONNECT host
|
||||
let sniDomain = extractSNI(from: buffer) ?? host
|
||||
|
||||
// Remove this handler — we'll rebuild the pipeline
|
||||
context.pipeline.removeHandler(self, promise: nil)
|
||||
|
||||
// Get TLS context for this domain
|
||||
let sslContext: NIOSSLContext
|
||||
do {
|
||||
sslContext = try certManager.tlsServerContext(for: sniDomain)
|
||||
} catch {
|
||||
print("[MITM] Failed to get TLS context for \(sniDomain): \(error)")
|
||||
context.close(promise: nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Add server-side TLS handler (we are the "server" to the client)
|
||||
let sslServerHandler = NIOSSLServerHandler(context: sslContext)
|
||||
let trafficRepo = self.trafficRepo
|
||||
let host = self.host
|
||||
let port = self.port
|
||||
|
||||
context.channel.pipeline.addHandler(sslServerHandler, position: .first).flatMap {
|
||||
// Add HTTP codec after TLS
|
||||
context.channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))
|
||||
}.flatMap {
|
||||
context.channel.pipeline.addHandler(HTTPResponseEncoder())
|
||||
}.flatMap {
|
||||
// Add the forwarding handler that connects to the real server
|
||||
context.channel.pipeline.addHandler(
|
||||
MITMForwardHandler(
|
||||
remoteHost: host,
|
||||
remotePort: port,
|
||||
domain: sniDomain,
|
||||
trafficRepo: trafficRepo
|
||||
)
|
||||
)
|
||||
}.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
// Re-fire the original ClientHello bytes so TLS handshake proceeds
|
||||
context.channel.pipeline.fireChannelRead(NIOAny(buffer))
|
||||
case .failure(let error):
|
||||
print("[MITM] Pipeline setup failed: \(error)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - SNI Extraction
|
||||
|
||||
/// Parse the SNI hostname from a TLS ClientHello message.
|
||||
private func extractSNI(from buffer: ByteBuffer) -> String? {
|
||||
var buf = buffer
|
||||
guard buf.readableBytes >= 43 else { return nil }
|
||||
|
||||
// TLS record header
|
||||
guard buf.readInteger(as: UInt8.self) == 0x16 else { return nil } // Handshake
|
||||
let _ = buf.readInteger(as: UInt16.self) // Version
|
||||
let _ = buf.readInteger(as: UInt16.self) // Length
|
||||
|
||||
// Handshake header
|
||||
guard buf.readInteger(as: UInt8.self) == 0x01 else { return nil } // ClientHello
|
||||
let _ = buf.readBytes(length: 3) // Length (3 bytes)
|
||||
|
||||
// Client version
|
||||
let _ = buf.readInteger(as: UInt16.self)
|
||||
// Random (32 bytes)
|
||||
guard buf.readBytes(length: 32) != nil else { return nil }
|
||||
// Session ID
|
||||
guard let sessionIdLen = buf.readInteger(as: UInt8.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(sessionIdLen)) != nil else { return nil }
|
||||
// Cipher suites
|
||||
guard let cipherSuitesLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(cipherSuitesLen)) != nil else { return nil }
|
||||
// Compression methods
|
||||
guard let compMethodsLen = buf.readInteger(as: UInt8.self) else { return nil }
|
||||
guard buf.readBytes(length: Int(compMethodsLen)) != nil else { return nil }
|
||||
|
||||
// Extensions
|
||||
guard let extensionsLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
var extensionsRemaining = Int(extensionsLen)
|
||||
|
||||
while extensionsRemaining > 4 {
|
||||
guard let extType = buf.readInteger(as: UInt16.self),
|
||||
let extLen = buf.readInteger(as: UInt16.self) else { return nil }
|
||||
extensionsRemaining -= 4 + Int(extLen)
|
||||
|
||||
if extType == 0x0000 { // SNI extension
|
||||
guard let _ = buf.readInteger(as: UInt16.self), // SNI list length
|
||||
let nameType = buf.readInteger(as: UInt8.self),
|
||||
nameType == 0x00, // hostname
|
||||
let nameLen = buf.readInteger(as: UInt16.self),
|
||||
let nameBytes = buf.readBytes(length: Int(nameLen)) else {
|
||||
return nil
|
||||
}
|
||||
return String(bytes: nameBytes, encoding: .utf8)
|
||||
} else {
|
||||
guard buf.readBytes(length: Int(extLen)) != nil else { return nil }
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - MITMForwardHandler
|
||||
|
||||
/// Handles decrypted HTTP from the client, forwards to the real server over TLS,
|
||||
/// and relays responses back. Captures everything via HTTPCaptureHandler.
|
||||
final class MITMForwardHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPServerRequestPart
|
||||
typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private let remoteHost: String
|
||||
private let remotePort: Int
|
||||
private let domain: String
|
||||
private let trafficRepo: TrafficRepository
|
||||
private var remoteChannel: Channel?
|
||||
|
||||
// Buffer request parts until upstream is connected
|
||||
private var pendingParts: [HTTPServerRequestPart] = []
|
||||
private var isConnected = false
|
||||
|
||||
init(remoteHost: String, remotePort: Int, domain: String, trafficRepo: TrafficRepository) {
|
||||
self.remoteHost = remoteHost
|
||||
self.remotePort = remotePort
|
||||
self.domain = domain
|
||||
self.trafficRepo = trafficRepo
|
||||
}
|
||||
|
||||
func handlerAdded(context: ChannelHandlerContext) {
|
||||
connectToRemote(context: context)
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
if isConnected, let remote = remoteChannel {
|
||||
// Forward to upstream as client request
|
||||
switch part {
|
||||
case .head(let head):
|
||||
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
|
||||
if !clientHead.headers.contains(name: "Host") {
|
||||
clientHead.headers.add(name: "Host", value: domain)
|
||||
}
|
||||
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
|
||||
case .body(let buffer):
|
||||
remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
|
||||
}
|
||||
} else {
|
||||
pendingParts.append(part)
|
||||
}
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
remoteChannel?.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[MITMForward] Error: \(error)")
|
||||
context.close(promise: nil)
|
||||
remoteChannel?.close(promise: nil)
|
||||
}
|
||||
|
||||
private func connectToRemote(context: ChannelHandlerContext) {
|
||||
let captureHandler = HTTPCaptureHandler(trafficRepo: trafficRepo, domain: domain, scheme: "https")
|
||||
let clientContext = context
|
||||
|
||||
do {
|
||||
let tlsConfig = TLSConfiguration.makeClientConfiguration()
|
||||
let sslContext = try NIOSSLContext(configuration: tlsConfig)
|
||||
|
||||
ClientBootstrap(group: context.eventLoop)
|
||||
.channelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
let sniHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.domain)
|
||||
return channel.pipeline.addHandler(sniHandler).flatMap {
|
||||
channel.pipeline.addHandler(HTTPRequestEncoder())
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(captureHandler)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(
|
||||
MITMRelayHandler(clientContext: clientContext)
|
||||
)
|
||||
}
|
||||
}
|
||||
.connect(host: remoteHost, port: remotePort)
|
||||
.whenComplete { result in
|
||||
switch result {
|
||||
case .success(let channel):
|
||||
self.remoteChannel = channel
|
||||
self.isConnected = true
|
||||
self.flushPending(remote: channel)
|
||||
case .failure(let error):
|
||||
print("[MITMForward] Connect to \(self.remoteHost):\(self.remotePort) failed: \(error)")
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[MITMForward] TLS setup failed: \(error)")
|
||||
context.close(promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
private func flushPending(remote: Channel) {
|
||||
for part in pendingParts {
|
||||
switch part {
|
||||
case .head(let head):
|
||||
var clientHead = HTTPRequestHead(version: head.version, method: head.method, uri: head.uri, headers: head.headers)
|
||||
if !clientHead.headers.contains(name: "Host") {
|
||||
clientHead.headers.add(name: "Host", value: domain)
|
||||
}
|
||||
remote.write(NIOAny(HTTPClientRequestPart.head(clientHead)), promise: nil)
|
||||
case .body(let buffer):
|
||||
remote.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
remote.writeAndFlush(NIOAny(HTTPClientRequestPart.end(trailers)), promise: nil)
|
||||
}
|
||||
}
|
||||
pendingParts.removeAll()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - MITMRelayHandler
|
||||
|
||||
/// Relays responses from the real server back to the proxy client.
|
||||
final class MITMRelayHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = HTTPClientResponsePart
|
||||
|
||||
private let clientContext: ChannelHandlerContext
|
||||
|
||||
init(clientContext: ChannelHandlerContext) {
|
||||
self.clientContext = clientContext
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let part = unwrapInboundIn(data)
|
||||
|
||||
switch part {
|
||||
case .head(let head):
|
||||
let serverResponse = HTTPResponseHead(version: head.version, status: head.status, headers: head.headers)
|
||||
clientContext.write(NIOAny(HTTPServerResponsePart.head(serverResponse)), promise: nil)
|
||||
case .body(let buffer):
|
||||
clientContext.write(NIOAny(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil)
|
||||
case .end(let trailers):
|
||||
clientContext.writeAndFlush(NIOAny(HTTPServerResponsePart.end(trailers)), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
func channelInactive(context: ChannelHandlerContext) {
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
|
||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
print("[MITMRelay] Error: \(error)")
|
||||
context.close(promise: nil)
|
||||
clientContext.close(promise: nil)
|
||||
}
|
||||
}
|
||||
55
ProxyCore/Sources/ProxyEngine/ProxyServer.swift
Normal file
55
ProxyCore/Sources/ProxyEngine/ProxyServer.swift
Normal file
@@ -0,0 +1,55 @@
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
import NIOHTTP1
|
||||
|
||||
public final class ProxyServer: Sendable {
|
||||
private let host: String
|
||||
private let port: Int
|
||||
private let group: EventLoopGroup
|
||||
private let trafficRepo: TrafficRepository
|
||||
nonisolated(unsafe) private var channel: Channel?
|
||||
|
||||
public init(
|
||||
host: String = ProxyConstants.proxyHost,
|
||||
port: Int = ProxyConstants.proxyPort,
|
||||
trafficRepo: TrafficRepository = TrafficRepository()
|
||||
) {
|
||||
self.host = host
|
||||
self.port = port
|
||||
// Use only 1 thread to conserve memory in the extension (50MB budget)
|
||||
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
self.trafficRepo = trafficRepo
|
||||
}
|
||||
|
||||
public func start() async throws {
|
||||
let trafficRepo = self.trafficRepo
|
||||
|
||||
let bootstrap = ServerBootstrap(group: group)
|
||||
.serverChannelOption(.backlog, value: 256)
|
||||
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.childChannelInitializer { channel in
|
||||
channel.pipeline.addHandler(
|
||||
ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
|
||||
).flatMap {
|
||||
channel.pipeline.addHandler(HTTPResponseEncoder())
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(ConnectHandler(trafficRepo: trafficRepo))
|
||||
}
|
||||
}
|
||||
.childChannelOption(.socketOption(.so_reuseaddr), value: 1)
|
||||
.childChannelOption(.maxMessagesPerRead, value: 16)
|
||||
|
||||
channel = try await bootstrap.bind(host: host, port: port).get()
|
||||
print("[ProxyServer] Listening on \(host):\(port)")
|
||||
}
|
||||
|
||||
public func stop() async {
|
||||
do {
|
||||
try await channel?.close()
|
||||
try await group.shutdownGracefully()
|
||||
} catch {
|
||||
print("[ProxyServer] Shutdown error: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
106
ProxyCore/Sources/Shared/CURLParser.swift
Normal file
106
ProxyCore/Sources/Shared/CURLParser.swift
Normal file
@@ -0,0 +1,106 @@
|
||||
import Foundation
|
||||
|
||||
public struct ParsedCURLRequest: Sendable {
|
||||
public var method: String = "GET"
|
||||
public var url: String = ""
|
||||
public var headers: [(key: String, value: String)] = []
|
||||
public var body: String?
|
||||
}
|
||||
|
||||
public enum CURLParser {
|
||||
public static func parse(_ curlString: String) -> ParsedCURLRequest? {
|
||||
var result = ParsedCURLRequest()
|
||||
let trimmed = curlString.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
guard trimmed.lowercased().hasPrefix("curl") else { return nil }
|
||||
|
||||
let tokens = tokenize(trimmed)
|
||||
var i = 0
|
||||
|
||||
while i < tokens.count {
|
||||
let token = tokens[i]
|
||||
|
||||
switch token {
|
||||
case "curl":
|
||||
break
|
||||
case "-X", "--request":
|
||||
i += 1
|
||||
if i < tokens.count {
|
||||
result.method = tokens[i].uppercased()
|
||||
}
|
||||
case "-H", "--header":
|
||||
i += 1
|
||||
if i < tokens.count {
|
||||
let header = tokens[i]
|
||||
if let colonIndex = header.firstIndex(of: ":") {
|
||||
let key = String(header[header.startIndex..<colonIndex]).trimmingCharacters(in: .whitespaces)
|
||||
let value = String(header[header.index(after: colonIndex)...]).trimmingCharacters(in: .whitespaces)
|
||||
result.headers.append((key: key, value: value))
|
||||
}
|
||||
}
|
||||
case "-d", "--data", "--data-raw", "--data-binary":
|
||||
i += 1
|
||||
if i < tokens.count {
|
||||
result.body = tokens[i]
|
||||
if result.method == "GET" {
|
||||
result.method = "POST"
|
||||
}
|
||||
}
|
||||
default:
|
||||
if !token.hasPrefix("-") && result.url.isEmpty {
|
||||
result.url = token
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
||||
return result.url.isEmpty ? nil : result
|
||||
}
|
||||
|
||||
private static func tokenize(_ input: String) -> [String] {
|
||||
var tokens: [String] = []
|
||||
var current = ""
|
||||
var inSingleQuote = false
|
||||
var inDoubleQuote = false
|
||||
var escaped = false
|
||||
|
||||
for char in input {
|
||||
if escaped {
|
||||
current.append(char)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if char == "\\" && !inSingleQuote {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if char == "'" && !inDoubleQuote {
|
||||
inSingleQuote.toggle()
|
||||
continue
|
||||
}
|
||||
|
||||
if char == "\"" && !inSingleQuote {
|
||||
inDoubleQuote.toggle()
|
||||
continue
|
||||
}
|
||||
|
||||
if char.isWhitespace && !inSingleQuote && !inDoubleQuote {
|
||||
if !current.isEmpty {
|
||||
tokens.append(current)
|
||||
current = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
current.append(char)
|
||||
}
|
||||
|
||||
if !current.isEmpty {
|
||||
tokens.append(current)
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
}
|
||||
18
ProxyCore/Sources/Shared/Constants.swift
Normal file
18
ProxyCore/Sources/Shared/Constants.swift
Normal file
@@ -0,0 +1,18 @@
|
||||
import Foundation
|
||||
|
||||
public enum ProxyConstants {
|
||||
public static let proxyHost = "127.0.0.1"
|
||||
public static let proxyPort: Int = 9090
|
||||
public static let appGroupIdentifier = "group.com.treyt.proxyapp"
|
||||
public static let extensionBundleIdentifier = "com.treyt.proxyapp.PacketTunnel"
|
||||
public static let maxBodySizeBytes = 1_048_576 // 1 MB - truncate larger bodies
|
||||
public static let certificateCacheSize = 500
|
||||
|
||||
public static let httpMethods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
|
||||
|
||||
public static let commonHeaders = [
|
||||
"Accept", "Accept-Charset", "Accept-Encoding", "Accept-Language",
|
||||
"Authorization", "Cache-Control", "Connection", "Content-Length",
|
||||
"Content-Type", "Cookie", "Host", "Origin", "Referer", "User-Agent"
|
||||
]
|
||||
}
|
||||
103
ProxyCore/Sources/Shared/IPCManager.swift
Normal file
103
ProxyCore/Sources/Shared/IPCManager.swift
Normal file
@@ -0,0 +1,103 @@
|
||||
import Foundation
|
||||
|
||||
/// Lightweight IPC between the main app and the packet tunnel extension
|
||||
/// using Darwin notifications (fire-and-forget signals) and shared UserDefaults.
|
||||
public final class IPCManager: Sendable {
|
||||
public static let shared = IPCManager()
|
||||
|
||||
private let suiteName = "group.com.treyt.proxyapp"
|
||||
|
||||
public enum Notification: String, Sendable {
|
||||
case newTrafficCaptured = "com.treyt.proxyapp.newTraffic"
|
||||
case configurationChanged = "com.treyt.proxyapp.configChanged"
|
||||
case extensionStarted = "com.treyt.proxyapp.extensionStarted"
|
||||
case extensionStopped = "com.treyt.proxyapp.extensionStopped"
|
||||
}
|
||||
|
||||
private init() {}
|
||||
|
||||
// MARK: - Darwin Notifications
|
||||
|
||||
public func post(_ notification: Notification) {
|
||||
let name = CFNotificationName(notification.rawValue as CFString)
|
||||
CFNotificationCenterPostNotification(
|
||||
CFNotificationCenterGetDarwinNotifyCenter(),
|
||||
name, nil, nil, true
|
||||
)
|
||||
}
|
||||
|
||||
public func observe(_ notification: Notification, callback: @escaping @Sendable () -> Void) {
|
||||
let name = notification.rawValue as CFString
|
||||
let center = CFNotificationCenterGetDarwinNotifyCenter()
|
||||
|
||||
// Store callback in a static dictionary keyed by notification name
|
||||
DarwinCallbackStore.shared.register(name: notification.rawValue, callback: callback)
|
||||
|
||||
CFNotificationCenterAddObserver(
|
||||
center, nil,
|
||||
{ _, _, name, _, _ in
|
||||
guard let cfName = name?.rawValue as? String else { return }
|
||||
DarwinCallbackStore.shared.fire(name: cfName)
|
||||
},
|
||||
name, nil,
|
||||
.deliverImmediately
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Shared UserDefaults
|
||||
|
||||
public var sharedDefaults: UserDefaults? {
|
||||
UserDefaults(suiteName: suiteName)
|
||||
}
|
||||
|
||||
public var isSSLProxyingEnabled: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "sslProxyingEnabled") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "sslProxyingEnabled") }
|
||||
}
|
||||
|
||||
public var isBlockListEnabled: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "blockListEnabled") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "blockListEnabled") }
|
||||
}
|
||||
|
||||
public var isBreakpointEnabled: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "breakpointEnabled") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "breakpointEnabled") }
|
||||
}
|
||||
|
||||
public var isNoCachingEnabled: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "noCachingEnabled") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "noCachingEnabled") }
|
||||
}
|
||||
|
||||
public var isDNSSpoofingEnabled: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "dnsSpoofingEnabled") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "dnsSpoofingEnabled") }
|
||||
}
|
||||
|
||||
public var hideSystemTraffic: Bool {
|
||||
get { sharedDefaults?.bool(forKey: "hideSystemTraffic") ?? false }
|
||||
set { sharedDefaults?.set(newValue, forKey: "hideSystemTraffic") }
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Darwin Callback Storage
|
||||
|
||||
private final class DarwinCallbackStore: @unchecked Sendable {
|
||||
static let shared = DarwinCallbackStore()
|
||||
private var callbacks: [String: @Sendable () -> Void] = [:]
|
||||
private let lock = NSLock()
|
||||
|
||||
func register(name: String, callback: @escaping @Sendable () -> Void) {
|
||||
lock.lock()
|
||||
callbacks[name] = callback
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func fire(name: String) {
|
||||
lock.lock()
|
||||
let cb = callbacks[name]
|
||||
lock.unlock()
|
||||
cb?()
|
||||
}
|
||||
}
|
||||
40
ProxyCore/Sources/Shared/WildcardMatcher.swift
Normal file
40
ProxyCore/Sources/Shared/WildcardMatcher.swift
Normal file
@@ -0,0 +1,40 @@
|
||||
import Foundation
|
||||
|
||||
public enum WildcardMatcher {
|
||||
/// Matches a string against a glob pattern with `*` (zero or more chars) and `?` (single char).
|
||||
public static func matches(_ string: String, pattern: String) -> Bool {
|
||||
let s = Array(string.lowercased())
|
||||
let p = Array(pattern.lowercased())
|
||||
return matchHelper(s, 0, p, 0)
|
||||
}
|
||||
|
||||
private static func matchHelper(_ s: [Character], _ si: Int, _ p: [Character], _ pi: Int) -> Bool {
|
||||
var si = si
|
||||
var pi = pi
|
||||
var starIdx = -1
|
||||
var matchIdx = 0
|
||||
|
||||
while si < s.count {
|
||||
if pi < p.count && (p[pi] == "?" || p[pi] == s[si]) {
|
||||
si += 1
|
||||
pi += 1
|
||||
} else if pi < p.count && p[pi] == "*" {
|
||||
starIdx = pi
|
||||
matchIdx = si
|
||||
pi += 1
|
||||
} else if starIdx != -1 {
|
||||
pi = starIdx + 1
|
||||
matchIdx += 1
|
||||
si = matchIdx
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
while pi < p.count && p[pi] == "*" {
|
||||
pi += 1
|
||||
}
|
||||
|
||||
return pi == p.count
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user