feat(scripts): add sportstime-parser data pipeline
Complete Python package for scraping, normalizing, and uploading sports schedule data to CloudKit. Includes: - Multi-source scrapers for NBA, MLB, NFL, NHL, MLS, WNBA, NWSL - Canonical ID system for teams, stadiums, and games - Fuzzy matching with manual alias support - CloudKit uploader with batch operations and deduplication - Comprehensive test suite with fixtures - WNBA abbreviation aliases for improved team resolution - Alias validation script to detect orphan references All 5 phases of data remediation plan completed: - Phase 1: Alias fixes (team/stadium alias additions) - Phase 2: NHL stadium coordinate fixes - Phase 3: Re-scrape validation - Phase 4: iOS bundle update - Phase 5: Code quality improvements (WNBA aliases) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
1
sportstime_parser/tests/test_uploaders/__init__.py
Normal file
1
sportstime_parser/tests/test_uploaders/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the uploaders module."""
|
||||
461
sportstime_parser/tests/test_uploaders/test_cloudkit.py
Normal file
461
sportstime_parser/tests/test_uploaders/test_cloudkit.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""Tests for the CloudKit client."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from sportstime_parser.uploaders.cloudkit import (
|
||||
CloudKitClient,
|
||||
CloudKitRecord,
|
||||
CloudKitError,
|
||||
CloudKitAuthError,
|
||||
CloudKitRateLimitError,
|
||||
CloudKitServerError,
|
||||
RecordType,
|
||||
OperationResult,
|
||||
BatchResult,
|
||||
)
|
||||
|
||||
|
||||
class TestCloudKitRecord:
|
||||
"""Tests for CloudKitRecord dataclass."""
|
||||
|
||||
def test_create_record(self):
|
||||
"""Test creating a CloudKitRecord."""
|
||||
record = CloudKitRecord(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type=RecordType.GAME,
|
||||
fields={
|
||||
"sport": "nba",
|
||||
"season": 2025,
|
||||
},
|
||||
)
|
||||
|
||||
assert record.record_name == "nba_2025_hou_okc_1021"
|
||||
assert record.record_type == RecordType.GAME
|
||||
assert record.fields["sport"] == "nba"
|
||||
assert record.record_change_tag is None
|
||||
|
||||
def test_to_cloudkit_dict(self):
|
||||
"""Test converting to CloudKit API format."""
|
||||
record = CloudKitRecord(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type=RecordType.GAME,
|
||||
fields={
|
||||
"sport": "nba",
|
||||
"season": 2025,
|
||||
},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["recordName"] == "nba_2025_hou_okc_1021"
|
||||
assert data["recordType"] == "Game"
|
||||
assert "fields" in data
|
||||
assert "recordChangeTag" not in data
|
||||
|
||||
def test_to_cloudkit_dict_with_change_tag(self):
|
||||
"""Test converting with change tag for updates."""
|
||||
record = CloudKitRecord(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type=RecordType.GAME,
|
||||
fields={"sport": "nba"},
|
||||
record_change_tag="abc123",
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["recordChangeTag"] == "abc123"
|
||||
|
||||
def test_format_string_field(self):
|
||||
"""Test formatting string fields."""
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.GAME,
|
||||
fields={"name": "Test Name"},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["fields"]["name"]["value"] == "Test Name"
|
||||
assert data["fields"]["name"]["type"] == "STRING"
|
||||
|
||||
def test_format_int_field(self):
|
||||
"""Test formatting integer fields."""
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.GAME,
|
||||
fields={"count": 42},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["fields"]["count"]["value"] == 42
|
||||
assert data["fields"]["count"]["type"] == "INT64"
|
||||
|
||||
def test_format_float_field(self):
|
||||
"""Test formatting float fields."""
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.STADIUM,
|
||||
fields={"latitude": 35.4634},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["fields"]["latitude"]["value"] == 35.4634
|
||||
assert data["fields"]["latitude"]["type"] == "DOUBLE"
|
||||
|
||||
def test_format_datetime_field(self):
|
||||
"""Test formatting datetime fields."""
|
||||
dt = datetime(2025, 10, 21, 19, 0, 0)
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.GAME,
|
||||
fields={"game_date": dt},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
expected_ms = int(dt.timestamp() * 1000)
|
||||
assert data["fields"]["game_date"]["value"] == expected_ms
|
||||
assert data["fields"]["game_date"]["type"] == "TIMESTAMP"
|
||||
|
||||
def test_format_location_field(self):
|
||||
"""Test formatting location fields."""
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.STADIUM,
|
||||
fields={
|
||||
"location": {"latitude": 35.4634, "longitude": -97.5151},
|
||||
},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert data["fields"]["location"]["type"] == "LOCATION"
|
||||
assert data["fields"]["location"]["value"]["latitude"] == 35.4634
|
||||
assert data["fields"]["location"]["value"]["longitude"] == -97.5151
|
||||
|
||||
def test_skip_none_fields(self):
|
||||
"""Test that None fields are skipped."""
|
||||
record = CloudKitRecord(
|
||||
record_name="test",
|
||||
record_type=RecordType.GAME,
|
||||
fields={
|
||||
"sport": "nba",
|
||||
"score": None, # Should be skipped
|
||||
},
|
||||
)
|
||||
|
||||
data = record.to_cloudkit_dict()
|
||||
|
||||
assert "sport" in data["fields"]
|
||||
assert "score" not in data["fields"]
|
||||
|
||||
|
||||
class TestOperationResult:
|
||||
"""Tests for OperationResult dataclass."""
|
||||
|
||||
def test_successful_result(self):
|
||||
"""Test creating a successful operation result."""
|
||||
result = OperationResult(
|
||||
record_name="test_record",
|
||||
success=True,
|
||||
record_change_tag="new_tag",
|
||||
)
|
||||
|
||||
assert result.record_name == "test_record"
|
||||
assert result.success is True
|
||||
assert result.record_change_tag == "new_tag"
|
||||
assert result.error_code is None
|
||||
|
||||
def test_failed_result(self):
|
||||
"""Test creating a failed operation result."""
|
||||
result = OperationResult(
|
||||
record_name="test_record",
|
||||
success=False,
|
||||
error_code="SERVER_ERROR",
|
||||
error_message="Internal server error",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_code == "SERVER_ERROR"
|
||||
assert result.error_message == "Internal server error"
|
||||
|
||||
|
||||
class TestBatchResult:
|
||||
"""Tests for BatchResult dataclass."""
|
||||
|
||||
def test_empty_batch_result(self):
|
||||
"""Test empty batch result."""
|
||||
result = BatchResult()
|
||||
|
||||
assert result.all_succeeded is True
|
||||
assert result.success_count == 0
|
||||
assert result.failure_count == 0
|
||||
|
||||
def test_batch_with_successes(self):
|
||||
"""Test batch with successful operations."""
|
||||
result = BatchResult()
|
||||
result.successful.append(OperationResult("rec1", True))
|
||||
result.successful.append(OperationResult("rec2", True))
|
||||
|
||||
assert result.all_succeeded is True
|
||||
assert result.success_count == 2
|
||||
assert result.failure_count == 0
|
||||
|
||||
def test_batch_with_failures(self):
|
||||
"""Test batch with failed operations."""
|
||||
result = BatchResult()
|
||||
result.successful.append(OperationResult("rec1", True))
|
||||
result.failed.append(OperationResult("rec2", False, error_message="Error"))
|
||||
|
||||
assert result.all_succeeded is False
|
||||
assert result.success_count == 1
|
||||
assert result.failure_count == 1
|
||||
|
||||
|
||||
class TestCloudKitClient:
|
||||
"""Tests for CloudKitClient."""
|
||||
|
||||
def test_not_configured_without_credentials(self):
|
||||
"""Test that client reports not configured without credentials."""
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
client = CloudKitClient()
|
||||
assert client.is_configured is False
|
||||
|
||||
def test_configured_with_credentials(self):
|
||||
"""Test that client reports configured with credentials."""
|
||||
# Create a minimal mock for the private key
|
||||
mock_key = MagicMock()
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key_id",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
mock_load.return_value = mock_key
|
||||
client = CloudKitClient()
|
||||
assert client.is_configured is True
|
||||
|
||||
def test_get_api_path(self):
|
||||
"""Test API path construction."""
|
||||
client = CloudKitClient(
|
||||
container_id="iCloud.com.test.app",
|
||||
environment="development",
|
||||
)
|
||||
|
||||
path = client._get_api_path("records/query")
|
||||
|
||||
assert path == "/database/1/iCloud.com.test.app/development/public/records/query"
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_fetch_records_query(self, mock_session_class):
|
||||
"""Test fetching records with query."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"records": [
|
||||
{"recordName": "rec1", "recordType": "Game"},
|
||||
{"recordName": "rec2", "recordType": "Game"},
|
||||
]
|
||||
}
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
# Setup client with mocked auth
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
records = client.fetch_records(RecordType.GAME)
|
||||
|
||||
assert len(records) == 2
|
||||
assert records[0]["recordName"] == "rec1"
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_save_records_success(self, mock_session_class):
|
||||
"""Test saving records successfully."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"records": [
|
||||
{"recordName": "rec1", "recordChangeTag": "tag1"},
|
||||
{"recordName": "rec2", "recordChangeTag": "tag2"},
|
||||
]
|
||||
}
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
|
||||
records = [
|
||||
CloudKitRecord("rec1", RecordType.GAME, {"sport": "nba"}),
|
||||
CloudKitRecord("rec2", RecordType.GAME, {"sport": "nba"}),
|
||||
]
|
||||
|
||||
result = client.save_records(records)
|
||||
|
||||
assert result.success_count == 2
|
||||
assert result.failure_count == 0
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_save_records_partial_failure(self, mock_session_class):
|
||||
"""Test saving records with some failures."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"records": [
|
||||
{"recordName": "rec1", "recordChangeTag": "tag1"},
|
||||
{"recordName": "rec2", "serverErrorCode": "QUOTA_EXCEEDED", "reason": "Quota exceeded"},
|
||||
]
|
||||
}
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
|
||||
records = [
|
||||
CloudKitRecord("rec1", RecordType.GAME, {"sport": "nba"}),
|
||||
CloudKitRecord("rec2", RecordType.GAME, {"sport": "nba"}),
|
||||
]
|
||||
|
||||
result = client.save_records(records)
|
||||
|
||||
assert result.success_count == 1
|
||||
assert result.failure_count == 1
|
||||
assert result.failed[0].error_code == "QUOTA_EXCEEDED"
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_auth_error(self, mock_session_class):
|
||||
"""Test handling authentication error."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 421
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
|
||||
with pytest.raises(CloudKitAuthError):
|
||||
client.fetch_records(RecordType.GAME)
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_rate_limit_error(self, mock_session_class):
|
||||
"""Test handling rate limit error."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
|
||||
with pytest.raises(CloudKitRateLimitError):
|
||||
client.fetch_records(RecordType.GAME)
|
||||
|
||||
@patch("sportstime_parser.uploaders.cloudkit.requests.Session")
|
||||
def test_server_error(self, mock_session_class):
|
||||
"""Test handling server error."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 503
|
||||
mock_session.request.return_value = mock_response
|
||||
|
||||
mock_key = MagicMock()
|
||||
mock_key.sign.return_value = b"signature"
|
||||
|
||||
with patch.dict("os.environ", {
|
||||
"CLOUDKIT_KEY_ID": "test_key",
|
||||
"CLOUDKIT_PRIVATE_KEY": "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
}):
|
||||
with patch("sportstime_parser.uploaders.cloudkit.serialization.load_pem_private_key") as mock_load:
|
||||
with patch("sportstime_parser.uploaders.cloudkit.jwt.encode") as mock_jwt:
|
||||
mock_load.return_value = mock_key
|
||||
mock_jwt.return_value = "test_token"
|
||||
|
||||
client = CloudKitClient()
|
||||
|
||||
with pytest.raises(CloudKitServerError):
|
||||
client.fetch_records(RecordType.GAME)
|
||||
|
||||
|
||||
class TestRecordType:
|
||||
"""Tests for RecordType enum."""
|
||||
|
||||
def test_record_type_values(self):
|
||||
"""Test that record type values match CloudKit schema."""
|
||||
assert RecordType.GAME.value == "Game"
|
||||
assert RecordType.TEAM.value == "Team"
|
||||
assert RecordType.STADIUM.value == "Stadium"
|
||||
assert RecordType.TEAM_ALIAS.value == "TeamAlias"
|
||||
assert RecordType.STADIUM_ALIAS.value == "StadiumAlias"
|
||||
350
sportstime_parser/tests/test_uploaders/test_diff.py
Normal file
350
sportstime_parser/tests/test_uploaders/test_diff.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Tests for the record differ."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from sportstime_parser.models.game import Game
|
||||
from sportstime_parser.models.team import Team
|
||||
from sportstime_parser.models.stadium import Stadium
|
||||
from sportstime_parser.uploaders.diff import (
|
||||
DiffAction,
|
||||
RecordDiff,
|
||||
DiffResult,
|
||||
RecordDiffer,
|
||||
game_to_cloudkit_record,
|
||||
team_to_cloudkit_record,
|
||||
stadium_to_cloudkit_record,
|
||||
)
|
||||
from sportstime_parser.uploaders.cloudkit import RecordType
|
||||
|
||||
|
||||
class TestRecordDiff:
|
||||
"""Tests for RecordDiff dataclass."""
|
||||
|
||||
def test_create_record_diff(self):
|
||||
"""Test creating a RecordDiff."""
|
||||
diff = RecordDiff(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.CREATE,
|
||||
)
|
||||
|
||||
assert diff.record_name == "nba_2025_hou_okc_1021"
|
||||
assert diff.record_type == RecordType.GAME
|
||||
assert diff.action == DiffAction.CREATE
|
||||
|
||||
|
||||
class TestDiffResult:
|
||||
"""Tests for DiffResult dataclass."""
|
||||
|
||||
def test_empty_result(self):
|
||||
"""Test empty DiffResult."""
|
||||
result = DiffResult()
|
||||
|
||||
assert result.create_count == 0
|
||||
assert result.update_count == 0
|
||||
assert result.delete_count == 0
|
||||
assert result.unchanged_count == 0
|
||||
assert result.total_changes == 0
|
||||
|
||||
def test_counts(self):
|
||||
"""Test counting different change types."""
|
||||
result = DiffResult()
|
||||
|
||||
result.creates.append(RecordDiff(
|
||||
record_name="game_1",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.CREATE,
|
||||
))
|
||||
result.creates.append(RecordDiff(
|
||||
record_name="game_2",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.CREATE,
|
||||
))
|
||||
result.updates.append(RecordDiff(
|
||||
record_name="game_3",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.UPDATE,
|
||||
))
|
||||
result.deletes.append(RecordDiff(
|
||||
record_name="game_4",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.DELETE,
|
||||
))
|
||||
result.unchanged.append(RecordDiff(
|
||||
record_name="game_5",
|
||||
record_type=RecordType.GAME,
|
||||
action=DiffAction.UNCHANGED,
|
||||
))
|
||||
|
||||
assert result.create_count == 2
|
||||
assert result.update_count == 1
|
||||
assert result.delete_count == 1
|
||||
assert result.unchanged_count == 1
|
||||
assert result.total_changes == 4 # excludes unchanged
|
||||
|
||||
|
||||
class TestRecordDiffer:
|
||||
"""Tests for RecordDiffer."""
|
||||
|
||||
@pytest.fixture
|
||||
def differ(self):
|
||||
"""Create a RecordDiffer instance."""
|
||||
return RecordDiffer()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_game(self):
|
||||
"""Create a sample Game."""
|
||||
return Game(
|
||||
id="nba_2025_hou_okc_1021",
|
||||
sport="nba",
|
||||
season=2025,
|
||||
home_team_id="team_nba_okc",
|
||||
away_team_id="team_nba_hou",
|
||||
stadium_id="stadium_nba_paycom_center",
|
||||
game_date=datetime(2025, 10, 21, 19, 0, 0),
|
||||
status="scheduled",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_team(self):
|
||||
"""Create a sample Team."""
|
||||
return Team(
|
||||
id="team_nba_okc",
|
||||
sport="nba",
|
||||
city="Oklahoma City",
|
||||
name="Thunder",
|
||||
full_name="Oklahoma City Thunder",
|
||||
abbreviation="OKC",
|
||||
conference="Western",
|
||||
division="Northwest",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_stadium(self):
|
||||
"""Create a sample Stadium."""
|
||||
return Stadium(
|
||||
id="stadium_nba_paycom_center",
|
||||
sport="nba",
|
||||
name="Paycom Center",
|
||||
city="Oklahoma City",
|
||||
state="OK",
|
||||
country="USA",
|
||||
latitude=35.4634,
|
||||
longitude=-97.5151,
|
||||
capacity=18203,
|
||||
)
|
||||
|
||||
def test_diff_games_create(self, differ, sample_game):
|
||||
"""Test detecting new games to create."""
|
||||
local_games = [sample_game]
|
||||
remote_records = []
|
||||
|
||||
result = differ.diff_games(local_games, remote_records)
|
||||
|
||||
assert result.create_count == 1
|
||||
assert result.update_count == 0
|
||||
assert result.delete_count == 0
|
||||
assert result.creates[0].record_name == sample_game.id
|
||||
|
||||
def test_diff_games_delete(self, differ, sample_game):
|
||||
"""Test detecting games to delete."""
|
||||
local_games = []
|
||||
remote_records = [
|
||||
{
|
||||
"recordName": sample_game.id,
|
||||
"recordType": "Game",
|
||||
"fields": {
|
||||
"sport": {"value": "nba", "type": "STRING"},
|
||||
"season": {"value": 2025, "type": "INT64"},
|
||||
},
|
||||
"recordChangeTag": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
result = differ.diff_games(local_games, remote_records)
|
||||
|
||||
assert result.create_count == 0
|
||||
assert result.delete_count == 1
|
||||
assert result.deletes[0].record_name == sample_game.id
|
||||
|
||||
def test_diff_games_unchanged(self, differ, sample_game):
|
||||
"""Test detecting unchanged games."""
|
||||
local_games = [sample_game]
|
||||
remote_records = [
|
||||
{
|
||||
"recordName": sample_game.id,
|
||||
"recordType": "Game",
|
||||
"fields": {
|
||||
"sport": {"value": "nba", "type": "STRING"},
|
||||
"season": {"value": 2025, "type": "INT64"},
|
||||
"home_team_id": {"value": "team_nba_okc", "type": "STRING"},
|
||||
"away_team_id": {"value": "team_nba_hou", "type": "STRING"},
|
||||
"stadium_id": {"value": "stadium_nba_paycom_center", "type": "STRING"},
|
||||
"game_date": {"value": int(sample_game.game_date.timestamp() * 1000), "type": "TIMESTAMP"},
|
||||
"game_number": {"value": None, "type": "INT64"},
|
||||
"home_score": {"value": None, "type": "INT64"},
|
||||
"away_score": {"value": None, "type": "INT64"},
|
||||
"status": {"value": "scheduled", "type": "STRING"},
|
||||
},
|
||||
"recordChangeTag": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
result = differ.diff_games(local_games, remote_records)
|
||||
|
||||
assert result.create_count == 0
|
||||
assert result.update_count == 0
|
||||
assert result.unchanged_count == 1
|
||||
|
||||
def test_diff_games_update(self, differ, sample_game):
|
||||
"""Test detecting games that need update."""
|
||||
local_games = [sample_game]
|
||||
# Remote has different status
|
||||
remote_records = [
|
||||
{
|
||||
"recordName": sample_game.id,
|
||||
"recordType": "Game",
|
||||
"fields": {
|
||||
"sport": {"value": "nba", "type": "STRING"},
|
||||
"season": {"value": 2025, "type": "INT64"},
|
||||
"home_team_id": {"value": "team_nba_okc", "type": "STRING"},
|
||||
"away_team_id": {"value": "team_nba_hou", "type": "STRING"},
|
||||
"stadium_id": {"value": "stadium_nba_paycom_center", "type": "STRING"},
|
||||
"game_date": {"value": int(sample_game.game_date.timestamp() * 1000), "type": "TIMESTAMP"},
|
||||
"game_number": {"value": None, "type": "INT64"},
|
||||
"home_score": {"value": None, "type": "INT64"},
|
||||
"away_score": {"value": None, "type": "INT64"},
|
||||
"status": {"value": "postponed", "type": "STRING"}, # Different!
|
||||
},
|
||||
"recordChangeTag": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
result = differ.diff_games(local_games, remote_records)
|
||||
|
||||
assert result.update_count == 1
|
||||
assert "status" in result.updates[0].changed_fields
|
||||
assert result.updates[0].record_change_tag == "abc123"
|
||||
|
||||
def test_diff_teams_create(self, differ, sample_team):
|
||||
"""Test detecting new teams to create."""
|
||||
local_teams = [sample_team]
|
||||
remote_records = []
|
||||
|
||||
result = differ.diff_teams(local_teams, remote_records)
|
||||
|
||||
assert result.create_count == 1
|
||||
assert result.creates[0].record_name == sample_team.id
|
||||
|
||||
def test_diff_stadiums_create(self, differ, sample_stadium):
|
||||
"""Test detecting new stadiums to create."""
|
||||
local_stadiums = [sample_stadium]
|
||||
remote_records = []
|
||||
|
||||
result = differ.diff_stadiums(local_stadiums, remote_records)
|
||||
|
||||
assert result.create_count == 1
|
||||
assert result.creates[0].record_name == sample_stadium.id
|
||||
|
||||
def test_get_records_to_upload(self, differ, sample_game):
|
||||
"""Test getting CloudKitRecords for upload."""
|
||||
game2 = Game(
|
||||
id="nba_2025_lal_lac_1022",
|
||||
sport="nba",
|
||||
season=2025,
|
||||
home_team_id="team_nba_lac",
|
||||
away_team_id="team_nba_lal",
|
||||
stadium_id="stadium_nba_crypto_com",
|
||||
game_date=datetime(2025, 10, 22, 19, 0, 0),
|
||||
status="scheduled",
|
||||
)
|
||||
|
||||
local_games = [sample_game, game2]
|
||||
# Only game2 exists remotely with different status
|
||||
remote_records = [
|
||||
{
|
||||
"recordName": game2.id,
|
||||
"recordType": "Game",
|
||||
"fields": {
|
||||
"sport": {"value": "nba", "type": "STRING"},
|
||||
"season": {"value": 2025, "type": "INT64"},
|
||||
"home_team_id": {"value": "team_nba_lac", "type": "STRING"},
|
||||
"away_team_id": {"value": "team_nba_lal", "type": "STRING"},
|
||||
"stadium_id": {"value": "stadium_nba_crypto_com", "type": "STRING"},
|
||||
"game_date": {"value": int(game2.game_date.timestamp() * 1000), "type": "TIMESTAMP"},
|
||||
"status": {"value": "postponed", "type": "STRING"}, # Different!
|
||||
},
|
||||
"recordChangeTag": "xyz789",
|
||||
}
|
||||
]
|
||||
|
||||
result = differ.diff_games(local_games, remote_records)
|
||||
records = result.get_records_to_upload()
|
||||
|
||||
assert len(records) == 2 # 1 create + 1 update
|
||||
record_names = [r.record_name for r in records]
|
||||
assert sample_game.id in record_names
|
||||
assert game2.id in record_names
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Tests for module-level convenience functions."""
|
||||
|
||||
def test_game_to_cloudkit_record(self):
|
||||
"""Test converting Game to CloudKitRecord."""
|
||||
game = Game(
|
||||
id="nba_2025_hou_okc_1021",
|
||||
sport="nba",
|
||||
season=2025,
|
||||
home_team_id="team_nba_okc",
|
||||
away_team_id="team_nba_hou",
|
||||
stadium_id="stadium_nba_paycom_center",
|
||||
game_date=datetime(2025, 10, 21, 19, 0, 0),
|
||||
status="scheduled",
|
||||
)
|
||||
|
||||
record = game_to_cloudkit_record(game)
|
||||
|
||||
assert record.record_name == game.id
|
||||
assert record.record_type == RecordType.GAME
|
||||
assert record.fields["sport"] == "nba"
|
||||
assert record.fields["season"] == 2025
|
||||
|
||||
def test_team_to_cloudkit_record(self):
|
||||
"""Test converting Team to CloudKitRecord."""
|
||||
team = Team(
|
||||
id="team_nba_okc",
|
||||
sport="nba",
|
||||
city="Oklahoma City",
|
||||
name="Thunder",
|
||||
full_name="Oklahoma City Thunder",
|
||||
abbreviation="OKC",
|
||||
)
|
||||
|
||||
record = team_to_cloudkit_record(team)
|
||||
|
||||
assert record.record_name == team.id
|
||||
assert record.record_type == RecordType.TEAM
|
||||
assert record.fields["city"] == "Oklahoma City"
|
||||
assert record.fields["name"] == "Thunder"
|
||||
|
||||
def test_stadium_to_cloudkit_record(self):
|
||||
"""Test converting Stadium to CloudKitRecord."""
|
||||
stadium = Stadium(
|
||||
id="stadium_nba_paycom_center",
|
||||
sport="nba",
|
||||
name="Paycom Center",
|
||||
city="Oklahoma City",
|
||||
state="OK",
|
||||
country="USA",
|
||||
latitude=35.4634,
|
||||
longitude=-97.5151,
|
||||
)
|
||||
|
||||
record = stadium_to_cloudkit_record(stadium)
|
||||
|
||||
assert record.record_name == stadium.id
|
||||
assert record.record_type == RecordType.STADIUM
|
||||
assert record.fields["name"] == "Paycom Center"
|
||||
assert record.fields["latitude"] == 35.4634
|
||||
472
sportstime_parser/tests/test_uploaders/test_state.py
Normal file
472
sportstime_parser/tests/test_uploaders/test_state.py
Normal file
@@ -0,0 +1,472 @@
|
||||
"""Tests for the upload state manager."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from sportstime_parser.uploaders.state import (
|
||||
RecordState,
|
||||
UploadSession,
|
||||
StateManager,
|
||||
)
|
||||
|
||||
|
||||
class TestRecordState:
|
||||
"""Tests for RecordState dataclass."""
|
||||
|
||||
def test_create_record_state(self):
|
||||
"""Test creating a RecordState with default values."""
|
||||
state = RecordState(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type="Game",
|
||||
)
|
||||
|
||||
assert state.record_name == "nba_2025_hou_okc_1021"
|
||||
assert state.record_type == "Game"
|
||||
assert state.status == "pending"
|
||||
assert state.uploaded_at is None
|
||||
assert state.record_change_tag is None
|
||||
assert state.error_message is None
|
||||
assert state.retry_count == 0
|
||||
|
||||
def test_record_state_to_dict(self):
|
||||
"""Test serializing RecordState to dictionary."""
|
||||
now = datetime.utcnow()
|
||||
state = RecordState(
|
||||
record_name="nba_2025_hou_okc_1021",
|
||||
record_type="Game",
|
||||
uploaded_at=now,
|
||||
record_change_tag="abc123",
|
||||
status="uploaded",
|
||||
)
|
||||
|
||||
data = state.to_dict()
|
||||
|
||||
assert data["record_name"] == "nba_2025_hou_okc_1021"
|
||||
assert data["record_type"] == "Game"
|
||||
assert data["status"] == "uploaded"
|
||||
assert data["uploaded_at"] == now.isoformat()
|
||||
assert data["record_change_tag"] == "abc123"
|
||||
|
||||
def test_record_state_from_dict(self):
|
||||
"""Test deserializing RecordState from dictionary."""
|
||||
data = {
|
||||
"record_name": "nba_2025_hou_okc_1021",
|
||||
"record_type": "Game",
|
||||
"uploaded_at": "2026-01-10T12:00:00",
|
||||
"record_change_tag": "abc123",
|
||||
"status": "uploaded",
|
||||
"error_message": None,
|
||||
"retry_count": 0,
|
||||
}
|
||||
|
||||
state = RecordState.from_dict(data)
|
||||
|
||||
assert state.record_name == "nba_2025_hou_okc_1021"
|
||||
assert state.record_type == "Game"
|
||||
assert state.status == "uploaded"
|
||||
assert state.uploaded_at == datetime.fromisoformat("2026-01-10T12:00:00")
|
||||
assert state.record_change_tag == "abc123"
|
||||
|
||||
|
||||
class TestUploadSession:
|
||||
"""Tests for UploadSession dataclass."""
|
||||
|
||||
def test_create_upload_session(self):
|
||||
"""Test creating an UploadSession."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
|
||||
assert session.sport == "nba"
|
||||
assert session.season == 2025
|
||||
assert session.environment == "development"
|
||||
assert session.total_count == 0
|
||||
assert len(session.records) == 0
|
||||
|
||||
def test_add_record(self):
|
||||
"""Test adding records to a session."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("team_1", "Team")
|
||||
|
||||
assert session.total_count == 3
|
||||
assert len(session.records) == 3
|
||||
assert "game_1" in session.records
|
||||
assert session.records["game_1"].record_type == "Game"
|
||||
|
||||
def test_mark_uploaded(self):
|
||||
"""Test marking a record as uploaded."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
|
||||
session.mark_uploaded("game_1", "change_tag_123")
|
||||
|
||||
assert session.records["game_1"].status == "uploaded"
|
||||
assert session.records["game_1"].record_change_tag == "change_tag_123"
|
||||
assert session.records["game_1"].uploaded_at is not None
|
||||
|
||||
def test_mark_failed(self):
|
||||
"""Test marking a record as failed."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
|
||||
session.mark_failed("game_1", "Server error")
|
||||
|
||||
assert session.records["game_1"].status == "failed"
|
||||
assert session.records["game_1"].error_message == "Server error"
|
||||
assert session.records["game_1"].retry_count == 1
|
||||
|
||||
def test_mark_failed_increments_retry_count(self):
|
||||
"""Test that marking failed increments retry count."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
|
||||
session.mark_failed("game_1", "Error 1")
|
||||
session.mark_failed("game_1", "Error 2")
|
||||
session.mark_failed("game_1", "Error 3")
|
||||
|
||||
assert session.records["game_1"].retry_count == 3
|
||||
|
||||
def test_counts(self):
|
||||
"""Test session counts."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("game_3", "Game")
|
||||
|
||||
session.mark_uploaded("game_1")
|
||||
session.mark_failed("game_2", "Error")
|
||||
|
||||
assert session.uploaded_count == 1
|
||||
assert session.failed_count == 1
|
||||
assert session.pending_count == 1
|
||||
|
||||
def test_is_complete(self):
|
||||
"""Test is_complete property."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
|
||||
assert not session.is_complete
|
||||
|
||||
session.mark_uploaded("game_1")
|
||||
assert not session.is_complete
|
||||
|
||||
session.mark_uploaded("game_2")
|
||||
assert session.is_complete
|
||||
|
||||
def test_progress_percent(self):
|
||||
"""Test progress percentage calculation."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("game_3", "Game")
|
||||
session.add_record("game_4", "Game")
|
||||
|
||||
session.mark_uploaded("game_1")
|
||||
|
||||
assert session.progress_percent == 25.0
|
||||
|
||||
def test_get_pending_records(self):
|
||||
"""Test getting pending record names."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("game_3", "Game")
|
||||
|
||||
session.mark_uploaded("game_1")
|
||||
session.mark_failed("game_2", "Error")
|
||||
|
||||
pending = session.get_pending_records()
|
||||
|
||||
assert pending == ["game_3"]
|
||||
|
||||
def test_get_failed_records(self):
|
||||
"""Test getting failed record names."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("game_3", "Game")
|
||||
|
||||
session.mark_failed("game_1", "Error 1")
|
||||
session.mark_failed("game_3", "Error 3")
|
||||
|
||||
failed = session.get_failed_records()
|
||||
|
||||
assert set(failed) == {"game_1", "game_3"}
|
||||
|
||||
def test_get_retryable_records(self):
|
||||
"""Test getting records eligible for retry."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.add_record("game_3", "Game")
|
||||
|
||||
# Fail game_1 once
|
||||
session.mark_failed("game_1", "Error")
|
||||
|
||||
# Fail game_2 three times (max retries)
|
||||
session.mark_failed("game_2", "Error")
|
||||
session.mark_failed("game_2", "Error")
|
||||
session.mark_failed("game_2", "Error")
|
||||
|
||||
retryable = session.get_retryable_records(max_retries=3)
|
||||
|
||||
assert retryable == ["game_1"]
|
||||
|
||||
def test_to_dict_and_from_dict(self):
|
||||
"""Test round-trip serialization."""
|
||||
session = UploadSession(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
)
|
||||
session.add_record("game_1", "Game")
|
||||
session.add_record("game_2", "Game")
|
||||
session.mark_uploaded("game_1", "tag_123")
|
||||
|
||||
data = session.to_dict()
|
||||
restored = UploadSession.from_dict(data)
|
||||
|
||||
assert restored.sport == session.sport
|
||||
assert restored.season == session.season
|
||||
assert restored.environment == session.environment
|
||||
assert restored.total_count == session.total_count
|
||||
assert restored.uploaded_count == session.uploaded_count
|
||||
assert restored.records["game_1"].status == "uploaded"
|
||||
|
||||
|
||||
class TestStateManager:
|
||||
"""Tests for StateManager."""
|
||||
|
||||
def test_create_session(self):
|
||||
"""Test creating a new session."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
session = manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[
|
||||
("game_1", "Game"),
|
||||
("game_2", "Game"),
|
||||
("team_1", "Team"),
|
||||
],
|
||||
)
|
||||
|
||||
assert session.sport == "nba"
|
||||
assert session.season == 2025
|
||||
assert session.total_count == 3
|
||||
|
||||
# Check file was created
|
||||
state_file = Path(tmpdir) / "upload_state_nba_2025_development.json"
|
||||
assert state_file.exists()
|
||||
|
||||
def test_load_session(self):
|
||||
"""Test loading an existing session."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
# Create and save a session
|
||||
original = manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game")],
|
||||
)
|
||||
original.mark_uploaded("game_1", "tag_123")
|
||||
manager.save_session(original)
|
||||
|
||||
# Load it back
|
||||
loaded = manager.load_session("nba", 2025, "development")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.sport == "nba"
|
||||
assert loaded.records["game_1"].status == "uploaded"
|
||||
|
||||
def test_load_nonexistent_session(self):
|
||||
"""Test loading a session that doesn't exist."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
session = manager.load_session("nba", 2025, "development")
|
||||
|
||||
assert session is None
|
||||
|
||||
def test_delete_session(self):
|
||||
"""Test deleting a session."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
# Create a session
|
||||
manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game")],
|
||||
)
|
||||
|
||||
# Delete it
|
||||
result = manager.delete_session("nba", 2025, "development")
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
loaded = manager.load_session("nba", 2025, "development")
|
||||
assert loaded is None
|
||||
|
||||
def test_delete_nonexistent_session(self):
|
||||
"""Test deleting a session that doesn't exist."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
result = manager.delete_session("nba", 2025, "development")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_list_sessions(self):
|
||||
"""Test listing all sessions."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
# Create multiple sessions
|
||||
manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game")],
|
||||
)
|
||||
manager.create_session(
|
||||
sport="mlb",
|
||||
season=2026,
|
||||
environment="production",
|
||||
record_names=[("game_2", "Game"), ("game_3", "Game")],
|
||||
)
|
||||
|
||||
sessions = manager.list_sessions()
|
||||
|
||||
assert len(sessions) == 2
|
||||
sports = {s["sport"] for s in sessions}
|
||||
assert sports == {"nba", "mlb"}
|
||||
|
||||
def test_get_session_or_create_new(self):
|
||||
"""Test getting a session when none exists."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
session = manager.get_session_or_create(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game")],
|
||||
resume=False,
|
||||
)
|
||||
|
||||
assert session.sport == "nba"
|
||||
assert session.total_count == 1
|
||||
|
||||
def test_get_session_or_create_resume(self):
|
||||
"""Test resuming an existing session."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
# Create initial session
|
||||
original = manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game"), ("game_2", "Game")],
|
||||
)
|
||||
original.mark_uploaded("game_1", "tag_123")
|
||||
manager.save_session(original)
|
||||
|
||||
# Resume with additional records
|
||||
session = manager.get_session_or_create(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game"), ("game_2", "Game"), ("game_3", "Game")],
|
||||
resume=True,
|
||||
)
|
||||
|
||||
# Should have original progress plus new record
|
||||
assert session.records["game_1"].status == "uploaded"
|
||||
assert "game_3" in session.records
|
||||
assert session.total_count == 3
|
||||
|
||||
def test_get_session_or_create_overwrite(self):
|
||||
"""Test overwriting an existing session when not resuming."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
manager = StateManager(state_dir=Path(tmpdir))
|
||||
|
||||
# Create initial session
|
||||
original = manager.create_session(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_1", "Game"), ("game_2", "Game")],
|
||||
)
|
||||
original.mark_uploaded("game_1", "tag_123")
|
||||
manager.save_session(original)
|
||||
|
||||
# Create new session (not resuming)
|
||||
session = manager.get_session_or_create(
|
||||
sport="nba",
|
||||
season=2025,
|
||||
environment="development",
|
||||
record_names=[("game_3", "Game")],
|
||||
resume=False,
|
||||
)
|
||||
|
||||
# Should be a fresh session
|
||||
assert session.total_count == 1
|
||||
assert "game_1" not in session.records
|
||||
assert "game_3" in session.records
|
||||
Reference in New Issue
Block a user