"""CloudKit Web Services client for sportstime-parser.

This module provides a client for uploading data to CloudKit using the
CloudKit Web Services API. It handles JWT authentication, request signing,
and batch operations.

Reference: https://developer.apple.com/documentation/cloudkitwebservices
"""

import base64
import hashlib
import json
import os
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Optional

import jwt
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec

from ..config import (
    CLOUDKIT_CONTAINER_ID,
    CLOUDKIT_ENVIRONMENT,
    CLOUDKIT_BATCH_SIZE,
    CLOUDKIT_KEY_ID,
    CLOUDKIT_PRIVATE_KEY_PATH,
)
from ..utils.logging import get_logger


class RecordType(str, Enum):
    """CloudKit record types for SportsTime.

    Must match CKRecordType constants in CKModels.swift.
    """

    GAME = "Game"
    TEAM = "Team"
    STADIUM = "Stadium"
    TEAM_ALIAS = "TeamAlias"
    STADIUM_ALIAS = "StadiumAlias"
    SPORT = "Sport"
    LEAGUE_STRUCTURE = "LeagueStructure"
    TRIP_POLL = "TripPoll"
    POLL_VOTE = "PollVote"
    ITINERARY_ITEM = "ItineraryItem"


@dataclass
class CloudKitRecord:
    """Represents a CloudKit record for upload.

    Attributes:
        record_name: Unique record identifier (canonical ID)
        record_type: CloudKit record type
        fields: Dictionary of field name -> field value
        record_change_tag: Version tag for conflict detection (None for new records)
    """

    record_name: str
    record_type: RecordType
    fields: dict[str, Any]
    record_change_tag: Optional[str] = None

    def to_cloudkit_dict(self) -> dict:
        """Convert to the record dictionary shape sent to the CloudKit API.

        ``recordChangeTag`` is only included when a non-empty tag is set.
        """
        record = {
            "recordName": self.record_name,
            "recordType": self.record_type.value,
            "fields": self._format_fields(),
        }
        if self.record_change_tag:
            record["recordChangeTag"] = self.record_change_tag
        return record

    def _format_fields(self) -> dict:
        """Format every field for the CloudKit API, omitting None values."""
        return {
            key: self._format_field_value(value)
            for key, value in self.fields.items()
            if value is not None
        }

    def _format_field_value(self, value: Any) -> dict:
        """Format a single field value as a CloudKit ``{"value", "type"}`` dict.

        Unrecognized types fall back to their ``str()`` form as a STRING.
        """
        # Check bool BEFORE int (bool is a subclass of int in Python)
        if isinstance(value, bool):
            return {"value": 1 if value else 0, "type": "INT64"}
        elif isinstance(value, str):
            return {"value": value, "type": "STRING"}
        elif isinstance(value, int):
            return {"value": value, "type": "INT64"}
        elif isinstance(value, float):
            return {"value": value, "type": "DOUBLE"}
        elif isinstance(value, datetime):
            # CloudKit expects milliseconds since epoch. Note: datetime.timestamp()
            # interprets naive datetimes in the local timezone.
            timestamp_ms = int(value.timestamp() * 1000)
            return {"value": timestamp_ms, "type": "TIMESTAMP"}
        elif isinstance(value, list):
            # All lists are sent as STRING_LIST regardless of element type.
            return {"value": value, "type": "STRING_LIST"}
        elif isinstance(value, dict) and "latitude" in value and "longitude" in value:
            return {
                "value": {
                    "latitude": value["latitude"],
                    "longitude": value["longitude"],
                },
                "type": "LOCATION",
            }
        else:
            # Default to string
            return {"value": str(value), "type": "STRING"}


@dataclass
class OperationResult:
    """Result of a CloudKit operation on a single record."""

    record_name: str
    success: bool
    record_change_tag: Optional[str] = None
    error_code: Optional[str] = None
    error_message: Optional[str] = None


@dataclass
class BatchResult:
    """Result of a batch CloudKit operation, split into successes and failures."""

    successful: list[OperationResult] = field(default_factory=list)
    failed: list[OperationResult] = field(default_factory=list)

    @property
    def all_succeeded(self) -> bool:
        return len(self.failed) == 0

    @property
    def success_count(self) -> int:
        return len(self.successful)

    @property
    def failure_count(self) -> int:
        return len(self.failed)


class CloudKitClient:
    """Client for CloudKit Web Services API.

    Handles authentication via server-to-server JWT tokens and provides
    methods for CRUD operations on CloudKit records in the public database.

    Authentication requires:
    - Key ID: CloudKit key identifier from Apple Developer Portal
    - Private Key: EC private key in PEM format

    Environment variables:
    - CLOUDKIT_KEY_ID: The key identifier
    - CLOUDKIT_PRIVATE_KEY_PATH: Path to the private key file
    - CLOUDKIT_PRIVATE_KEY: The private key contents (alternative to path)

    The client owns a ``requests.Session``; call :meth:`close` (or use the
    client as a context manager) to release its pooled connections.
    """

    BASE_URL = "https://api.apple-cloudkit.com"
    TOKEN_EXPIRY_SECONDS = 3600  # 1 hour
    # Per-request network timeout in seconds, so a stalled connection cannot
    # hang an upload forever. Override on a subclass or instance if needed.
    REQUEST_TIMEOUT_SECONDS = 60

    def __init__(
        self,
        container_id: str = CLOUDKIT_CONTAINER_ID,
        environment: str = CLOUDKIT_ENVIRONMENT,
        key_id: Optional[str] = None,
        private_key: Optional[str] = None,
        private_key_path: Optional[str] = None,
    ):
        """Initialize the CloudKit client.

        Credentials are resolved with this precedence:
        explicit arguments > environment variables > config defaults.

        Args:
            container_id: CloudKit container identifier
            environment: 'development' or 'production'
            key_id: CloudKit server-to-server key ID
            private_key: PEM-encoded EC private key contents
            private_key_path: Path to PEM-encoded EC private key file
        """
        self.container_id = container_id
        self.environment = environment
        self.logger = get_logger()

        # Key ID: explicit argument > env var > config default.
        self.key_id = key_id or os.environ.get("CLOUDKIT_KEY_ID") or CLOUDKIT_KEY_ID

        # Private key: explicit contents > explicit path > env contents >
        # env path > config default path (only if it points at a file).
        if private_key:
            self._private_key_pem = private_key
        elif private_key_path:
            self._private_key_pem = Path(private_key_path).read_text()
        elif os.environ.get("CLOUDKIT_PRIVATE_KEY"):
            self._private_key_pem = os.environ["CLOUDKIT_PRIVATE_KEY"]
        elif os.environ.get("CLOUDKIT_PRIVATE_KEY_PATH"):
            self._private_key_pem = Path(os.environ["CLOUDKIT_PRIVATE_KEY_PATH"]).read_text()
        elif CLOUDKIT_PRIVATE_KEY_PATH and Path(CLOUDKIT_PRIVATE_KEY_PATH).is_file():
            # Path() accepts both str and Path config values.
            self._private_key_pem = Path(CLOUDKIT_PRIVATE_KEY_PATH).read_text()
        else:
            self._private_key_pem = None

        # Parse the private key if available
        self._private_key = None
        if self._private_key_pem:
            self._private_key = serialization.load_pem_private_key(
                self._private_key_pem.encode(),
                password=None,
                backend=default_backend(),
            )

        # Token cache
        self._token: Optional[str] = None
        self._token_expiry: float = 0

        # Session for connection pooling
        self._session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session and release pooled connections."""
        self._session.close()

    def __enter__(self) -> "CloudKitClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @property
    def is_configured(self) -> bool:
        """Check if the client has valid authentication credentials."""
        return bool(self.key_id and self._private_key)

    def _get_api_path(self, operation: str) -> str:
        """Build the full API path for an operation on the public database."""
        return f"/database/1/{self.container_id}/{self.environment}/public/{operation}"

    def _get_token(self) -> str:
        """Get a valid JWT token, generating a new one if needed.

        Raises:
            ValueError: If the client has no key ID or private key.
        """
        if not self.is_configured:
            raise ValueError(
                "CloudKit client not configured. Set CLOUDKIT_KEY_ID and "
                "CLOUDKIT_PRIVATE_KEY_PATH environment variables."
            )

        now = time.time()

        # Return cached token if still valid (with 5 min buffer)
        if self._token and (self._token_expiry - now) > 300:
            return self._token

        # Generate new token
        expiry = now + self.TOKEN_EXPIRY_SECONDS
        payload = {
            "iss": self.key_id,
            "iat": int(now),
            "exp": int(expiry),
            "sub": self.container_id,
        }

        self._token = jwt.encode(
            payload,
            self._private_key,
            algorithm="ES256",
        )
        self._token_expiry = expiry

        return self._token

    def _sign_request(self, method: str, path: str, body: Optional[bytes] = None) -> dict:
        """Generate request headers with authentication.

        Args:
            method: HTTP method (currently not part of the signed message)
            path: API path
            body: Request body bytes (None is hashed as an empty body)

        Returns:
            Dictionary of headers to include in the request
        """
        token = self._get_token()

        # CloudKit uses an ISO 8601 UTC date without fractional seconds.
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

        # Base64-encoded SHA-256 of the body; a missing body hashes as b"".
        body_hash = base64.b64encode(hashlib.sha256(body or b"").digest()).decode()

        # Build the message to sign
        # Format: date:body_hash:path
        message = f"{date_str}:{body_hash}:{path}"

        # Sign the message with ECDSA over SHA-256
        signature = self._private_key.sign(
            message.encode(),
            ec.ECDSA(hashes.SHA256()),
        )
        signature_b64 = base64.b64encode(signature).decode()

        return {
            "Authorization": f"Bearer {token}",
            "X-Apple-CloudKit-Request-KeyID": self.key_id,
            "X-Apple-CloudKit-Request-ISO8601Date": date_str,
            "X-Apple-CloudKit-Request-SignatureV1": signature_b64,
            "Content-Type": "application/json",
        }

    def _request(
        self,
        method: str,
        operation: str,
        body: Optional[dict] = None,
    ) -> dict:
        """Make a request to the CloudKit API.

        Args:
            method: HTTP method
            operation: API operation path
            body: Request body as dictionary

        Returns:
            Response data as dictionary

        Raises:
            CloudKitAuthError: On HTTP 401 or 421 (the cached token is dropped)
            CloudKitRateLimitError: On HTTP 429
            CloudKitServerError: On HTTP 5xx
            CloudKitError: On any other non-200 status
            requests.RequestException: On network failures or timeouts
        """
        path = self._get_api_path(operation)
        url = f"{self.BASE_URL}{path}"

        body_bytes = json.dumps(body).encode() if body else None
        headers = self._sign_request(method, path, body_bytes)

        response = self._session.request(
            method=method,
            url=url,
            headers=headers,
            data=body_bytes,
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )

        status = response.status_code
        if status == 200:
            return response.json()
        if status in (401, 421):
            # 401 = authentication failed, 421 = authentication required.
            # Drop the cached token so the next call generates a fresh one.
            self._token = None
            raise CloudKitAuthError("Authentication failed - check credentials")
        if status == 429:
            raise CloudKitRateLimitError("Rate limit exceeded")
        if status >= 500:
            raise CloudKitServerError(f"Server error: {status}")
        raise CloudKitError(f"Request failed: {self._describe_error(response)}")

    @staticmethod
    def _describe_error(response: Any) -> str:
        """Build a readable error description from a failed HTTP response.

        Uses ``serverErrorCode`` (and ``reason`` when present) from a JSON
        body, falling back to the raw response text.
        """
        try:
            error_data = response.json()
        except ValueError:
            # Not JSON. requests' JSONDecodeError (all versions, with or
            # without simplejson) is a ValueError subclass.
            return response.text
        if not isinstance(error_data, dict):
            return response.text
        code = error_data.get("serverErrorCode", str(response.status_code))
        reason = error_data.get("reason")
        return f"{code} ({reason})" if reason else str(code)

    @staticmethod
    def _extract_records(response: dict) -> list[dict]:
        """Return the record dictionaries from a lookup/query response.

        Entries without a ``recordName`` and per-record error entries (those
        carrying ``serverErrorCode``, e.g. a lookup of a missing record) are
        skipped.
        """
        return [
            r
            for r in response.get("records", [])
            if "recordName" in r and "serverErrorCode" not in r
        ]

    def fetch_records(
        self,
        record_type: RecordType,
        record_names: Optional[list[str]] = None,
        limit: int = 200,
    ) -> list[dict]:
        """Fetch records from CloudKit.

        Args:
            record_type: Type of records to fetch (used only for queries;
                lookups by name ignore it)
            record_names: Specific record names to fetch (optional)
            limit: Maximum records to return for a query (default 200)

        Returns:
            List of record dictionaries. Names that could not be fetched are
            omitted rather than returned as error entries.
        """
        if record_names:
            # Fetch specific records by name
            body = {
                "records": [{"recordName": name} for name in record_names],
            }
            response = self._request("POST", "records/lookup", body)
        else:
            # Query records of the type (single page, no pagination)
            body = {
                "query": {
                    "recordType": record_type.value,
                },
                "resultsLimit": limit,
            }
            response = self._request("POST", "records/query", body)

        return self._extract_records(response)

    def fetch_all_records(self, record_type: RecordType) -> list[dict]:
        """Fetch all records of a type using pagination.

        Follows ``continuationMarker`` until the server stops returning one.

        Args:
            record_type: Type of records to fetch

        Returns:
            List of all record dictionaries
        """
        all_records: list[dict] = []
        continuation_marker = None

        while True:
            body = {
                "query": {
                    "recordType": record_type.value,
                },
                "resultsLimit": 200,
            }

            if continuation_marker:
                body["continuationMarker"] = continuation_marker

            response = self._request("POST", "records/query", body)
            all_records.extend(self._extract_records(response))

            continuation_marker = response.get("continuationMarker")
            if not continuation_marker:
                break

        return all_records

    def save_records(self, records: list[CloudKitRecord]) -> BatchResult:
        """Save records to CloudKit (create or update).

        Records are sent in batches of CLOUDKIT_BATCH_SIZE.

        Args:
            records: List of records to save

        Returns:
            BatchResult with success/failure details
        """
        result = BatchResult()

        # Process in batches
        for i in range(0, len(records), CLOUDKIT_BATCH_SIZE):
            batch = records[i:i + CLOUDKIT_BATCH_SIZE]
            batch_result = self._save_batch(batch)
            result.successful.extend(batch_result.successful)
            result.failed.extend(batch_result.failed)

        return result

    @staticmethod
    def _fail_all(record_names: list[str], error: Exception, result: BatchResult) -> None:
        """Record every name in a batch as failed with the same error."""
        for record_name in record_names:
            result.failed.append(OperationResult(
                record_name=record_name,
                success=False,
                error_message=str(error),
            ))

    @staticmethod
    def _collect_modify_results(
        response: dict,
        result: BatchResult,
        include_change_tag: bool = True,
    ) -> None:
        """Sort per-record entries of a records/modify response into ``result``.

        Entries carrying ``serverErrorCode`` are failures; all others succeed.
        """
        for record_data in response.get("records", []):
            record_name = record_data.get("recordName", "unknown")
            if "serverErrorCode" in record_data:
                result.failed.append(OperationResult(
                    record_name=record_name,
                    success=False,
                    error_code=record_data.get("serverErrorCode"),
                    error_message=record_data.get("reason"),
                ))
            else:
                result.successful.append(OperationResult(
                    record_name=record_name,
                    success=True,
                    record_change_tag=(
                        record_data.get("recordChangeTag") if include_change_tag else None
                    ),
                ))

    def _save_batch(self, records: list[CloudKitRecord]) -> BatchResult:
        """Save a single batch of records.

        Uses ``forceReplace``, so any record change tag is not used for
        conflict detection here.

        Args:
            records: List of records (max CLOUDKIT_BATCH_SIZE)

        Returns:
            BatchResult with success/failure details
        """
        result = BatchResult()

        operations = [
            {
                "operationType": "forceReplace",
                "record": record.to_cloudkit_dict(),
            }
            for record in records
        ]
        body = {"operations": operations}

        try:
            response = self._request("POST", "records/modify", body)
        except CloudKitError as e:
            # Entire batch failed
            self._fail_all([record.record_name for record in records], e, result)
            return result

        # Process individual results
        self._collect_modify_results(response, result)
        return result

    def delete_records(
        self,
        record_type: RecordType,
        records: list[dict],
    ) -> BatchResult:
        """Delete records from CloudKit.

        Args:
            record_type: Type of records to delete (currently unused)
            records: List of record dicts (must have recordName and recordChangeTag)

        Returns:
            BatchResult with success/failure details
        """
        result = BatchResult()

        # Process in batches
        for i in range(0, len(records), CLOUDKIT_BATCH_SIZE):
            batch = records[i:i + CLOUDKIT_BATCH_SIZE]

            operations = [
                {
                    "operationType": "delete",
                    "record": {
                        "recordName": record["recordName"],
                        "recordChangeTag": record.get("recordChangeTag"),
                    },
                }
                for record in batch
            ]
            body = {"operations": operations}

            try:
                response = self._request("POST", "records/modify", body)
            except CloudKitError as e:
                # Whole batch failed; keep going with the next one.
                self._fail_all([record["recordName"] for record in batch], e, result)
                continue

            self._collect_modify_results(response, result, include_change_tag=False)

        return result


class CloudKitError(Exception):
    """Base exception for CloudKit errors."""
    pass


class CloudKitAuthError(CloudKitError):
    """Authentication error (raised for HTTP 401 and 421 responses)."""
    pass


class CloudKitRateLimitError(CloudKitError):
    """Rate limit exceeded (HTTP 429)."""
    pass


class CloudKitServerError(CloudKitError):
    """Server-side error (HTTP 5xx)."""
    pass