from datetime import date

from django.test import TestCase
from django.contrib.auth.models import User
from rest_framework.test import APIClient
from rest_framework.authtoken.models import Token

from registered_user.models import RegisteredUser
from generator.models import UserPreference, WorkoutType


class TestInjurySafety(TestCase):
    """Tests for injury-related preference round-trip and warning generation.

    Each test runs as a token-authenticated user that owns a single
    ``UserPreference`` row and has one ``WorkoutType`` available.
    """

    # Endpoints exercised by these tests.
    PREFERENCES_URL = '/generator/preferences/'
    PREFERENCES_UPDATE_URL = '/generator/preferences/update/'
    PREVIEW_URL = '/generator/preview/'

    def setUp(self):
        """Create the user, auth token, API client, preference and workout type."""
        self.django_user = User.objects.create_user(
            username='testuser',
            password='testpass123',
            email='test@example.com',
        )
        self.registered_user = RegisteredUser.objects.create(
            user=self.django_user,
            first_name='Test',
            last_name='User',
        )
        self.token = Token.objects.create(user=self.django_user)

        self.client = APIClient()
        self.client.credentials(HTTP_AUTHORIZATION=f'Token {self.token.key}')

        self.preference = UserPreference.objects.create(
            registered_user=self.registered_user,
            days_per_week=3,
        )

        # Create a basic workout type for generation
        self.workout_type = WorkoutType.objects.create(
            name='functional_strength_training',
            display_name='Functional Strength',
            typical_rest_between_sets=60,
            typical_intensity='medium',
            rep_range_min=8,
            rep_range_max=12,
            round_range_min=3,
            round_range_max=4,
            duration_bias=0.3,
            superset_size_min=2,
            superset_size_max=4,
        )

    def _put_injuries(self, injuries):
        """PUT ``injuries`` as ``injury_types`` to the update endpoint.

        Returns the response object so callers can assert on it.
        """
        return self.client.put(
            self.PREFERENCES_UPDATE_URL,
            {'injury_types': injuries},
            format='json',
        )

    def test_injury_types_roundtrip(self):
        """PUT injury_types, GET back, verify data persists."""
        injuries = [
            {'type': 'knee', 'severity': 'moderate'},
            {'type': 'shoulder', 'severity': 'mild'},
        ]
        response = self._put_injuries(injuries)
        self.assertEqual(response.status_code, 200)

        # GET back
        response = self.client.get(self.PREFERENCES_URL)
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertEqual(len(data['injury_types']), 2)
        types_set = {i['type'] for i in data['injury_types']}
        self.assertIn('knee', types_set)
        self.assertIn('shoulder', types_set)

    def test_injury_types_validation_rejects_invalid_type(self):
        """Invalid injury type should be rejected."""
        response = self._put_injuries([{'type': 'elbow', 'severity': 'mild'}])
        self.assertEqual(response.status_code, 400)

    def test_injury_types_validation_rejects_invalid_severity(self):
        """Invalid severity should be rejected."""
        response = self._put_injuries([{'type': 'knee', 'severity': 'extreme'}])
        self.assertEqual(response.status_code, 400)

    def test_severe_knee_excludes_high_impact(self):
        """Set knee:severe, verify the exercise selector filters correctly."""
        from generator.services.exercise_selector import ExerciseSelector

        self.preference.injury_types = [
            {'type': 'knee', 'severity': 'severe'},
        ]
        self.preference.save()

        selector = ExerciseSelector(self.preference)
        qs = selector._get_filtered_queryset()

        # NOTE(review): this test creates no exercise rows itself, so unless
        # they are seeded elsewhere both counts below are trivially zero.

        # No high-impact exercises should remain
        high_impact = qs.filter(impact_level='high')
        self.assertEqual(high_impact.count(), 0)

        # No medium-impact exercises either (severe lower body)
        medium_impact = qs.filter(impact_level='medium')
        self.assertEqual(medium_impact.count(), 0)

        # Warnings should mention the injury
        self.assertTrue(
            any('knee' in w.lower() for w in selector.warnings),
            f'Expected knee-related warning, got: {selector.warnings}'
        )

    def test_no_injuries_full_pool(self):
        """Empty injury_types should not produce injury-based warnings."""
        from generator.services.exercise_selector import ExerciseSelector

        self.preference.injury_types = []
        self.preference.save()

        selector = ExerciseSelector(self.preference)
        # The queryset itself is not inspected; the call is kept because it
        # presumably is what populates ``selector.warnings`` -- TODO confirm.
        selector._get_filtered_queryset()

        # With no injuries, there should be no injury-based warnings
        injury_warnings = [w for w in selector.warnings if 'injury' in w.lower()]
        self.assertEqual(len(injury_warnings), 0)

    def test_warnings_in_preview_response(self):
        """With injuries set, verify the preview's warnings value is a list when present."""
        self.preference.injury_types = [
            {'type': 'knee', 'severity': 'moderate'},
        ]
        self.preference.save()
        self.preference.preferred_workout_types.add(self.workout_type)

        response = self.client.post(
            self.PREVIEW_URL,
            {'week_start_date': '2026-03-02'},
            format='json',
        )

        # NOTE(review): a 500 is tolerated here (apparently to allow for a
        # limited exercise pool), so a server error does not fail this test.
        # Tighten to 200 once the preview can be generated reliably.
        self.assertIn(response.status_code, [200, 500])
        if response.status_code == 200:
            data = response.json()
            # The warnings key should exist if injuries triggered any warnings
            if 'warnings' in data:
                self.assertIsInstance(data['warnings'], list)

    def test_backward_compat_string_injuries(self):
        """Legacy string format should be accepted and normalized."""
        response = self._put_injuries(['knee', 'shoulder'])
        self.assertEqual(response.status_code, 200)

        # Verify normalized to dict format
        response = self.client.get(self.PREFERENCES_URL)
        self.assertEqual(response.status_code, 200)
        data = response.json()

        # Guard against a vacuous pass: the loop below asserts nothing when
        # the list comes back empty.
        self.assertEqual(len(data['injury_types']), 2)

        for injury in data['injury_types']:
            with self.subTest(injury=injury):
                self.assertIn('type', injury)
                self.assertIn('severity', injury)
                self.assertEqual(injury['severity'], 'moderate')

        self.assertEqual(
            {injury['type'] for injury in data['injury_types']},
            {'knee', 'shoulder'},
        )