Skip to content

Commit

Permalink
TST: update tests according to last Transmission modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrea Blengino committed Nov 25, 2023
1 parent 1ccc18e commit 821e714
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 28 deletions.
54 changes: 41 additions & 13 deletions tests/test_transmission/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,45 @@ def transmission_update_time_type_error(request):
transmission_snapshot_type_error_1 = [{'target_time': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, Time)]

transmission_snapshot_type_error_2 = [{'target_time': Time(1, 'sec'), 'angular_position_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]
transmission_snapshot_type_error_2 = [{'target_time': Time(1, 'sec'), 'variables': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, list)
and type_to_check is not None]

transmission_snapshot_type_error_3 = [{'target_time': Time(1, 'sec'), 'angular_speed_unit': type_to_check}
transmission_snapshot_type_error_3 = [{'target_time': Time(1, 'sec'), 'variables': [type_to_check]}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_4 = [{'target_time': Time(1, 'sec'), 'angular_acceleration_unit': type_to_check}
transmission_snapshot_type_error_4 = [{'target_time': Time(1, 'sec'), 'angular_position_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_5 = [{'target_time': Time(1, 'sec'), 'torque_unit': type_to_check}
transmission_snapshot_type_error_5 = [{'target_time': Time(1, 'sec'), 'angular_speed_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_6 = [{'target_time': Time(1, 'sec'), 'driving_torque_unit': type_to_check}
transmission_snapshot_type_error_6 = [{'target_time': Time(1, 'sec'), 'angular_acceleration_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_7 = [{'target_time': Time(1, 'sec'), 'load_torque_unit': type_to_check}
transmission_snapshot_type_error_7 = [{'target_time': Time(1, 'sec'), 'torque_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_8 = [{'target_time': Time(1, 'sec'), 'force_unit': type_to_check}
transmission_snapshot_type_error_8 = [{'target_time': Time(1, 'sec'), 'driving_torque_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_9 = [{'target_time': Time(1, 'sec'), 'stress_unit': type_to_check}
transmission_snapshot_type_error_9 = [{'target_time': Time(1, 'sec'), 'load_torque_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_10 = [{'target_time': Time(1, 'sec'), 'current_unit': type_to_check}
transmission_snapshot_type_error_10 = [{'target_time': Time(1, 'sec'), 'force_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_11 = [{'target_time': Time(1, 'sec'), 'stress_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_12 = [{'target_time': Time(1, 'sec'), 'current_unit': type_to_check}
for type_to_check in types_to_check if not isinstance(type_to_check, str)]

transmission_snapshot_type_error_11 = [{'target_time': Time(1, 'sec'), 'print_data': type_to_check}
transmission_snapshot_type_error_13 = [{'target_time': Time(1, 'sec'), 'print_data': type_to_check}
for type_to_check in types_to_check
if not isinstance(type_to_check, int) and not isinstance(type_to_check, bool)]

transmission_snapshot_type_error_12 = [{}]
transmission_snapshot_type_error_14 = [{}]

@fixture(params = [*transmission_snapshot_type_error_1,
*transmission_snapshot_type_error_2,
Expand All @@ -61,11 +68,32 @@ def transmission_update_time_type_error(request):
*transmission_snapshot_type_error_9,
*transmission_snapshot_type_error_10,
*transmission_snapshot_type_error_11,
*transmission_snapshot_type_error_12])
*transmission_snapshot_type_error_12,
*transmission_snapshot_type_error_13,
*transmission_snapshot_type_error_14])
def transmission_snapshot_type_error(request):
return request.param


transmission_snapshot_value_error_1 = [{'target_time': max(basic_transmission.time) + Time(1, 'sec')}]

transmission_snapshot_value_error_2 = [{'target_time': min(basic_transmission.time) - Time(1, 'sec')}]

transmission_snapshot_value_error_3 = [{'target_time': max(basic_transmission.time), 'variables': []}]

transmission_snapshot_value_error_4 = [{'target_time': max(basic_transmission.time), 'variables': ['not a valid time variable']}]

transmission_snapshot_value_error_5 = [{}]

@fixture(params = [*transmission_snapshot_value_error_1,
*transmission_snapshot_value_error_2,
*transmission_snapshot_value_error_3,
*transmission_snapshot_value_error_4,
*transmission_snapshot_value_error_5])
def transmission_snapshot_value_error(request):
return request.param


elements = [basic_transmission.chain[0]]
variables = list(basic_transmission.chain[0].time_variables.keys())

Expand Down
44 changes: 29 additions & 15 deletions tests/test_transmission/test_transmission.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from gearpy.mechanical_object import DCMotor, SpurGear
from gearpy.transmission import Transmission
from gearpy.units import AngularAcceleration, AngularPosition, AngularSpeed, Current, Force, InertiaMoment, Length, \
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_raises_name_error(self):
@mark.transmission
class TestTransmissionUpdateTime:


@mark.genuine
@given(transmission = transmissions(),
instant = times())
Expand All @@ -90,6 +92,7 @@ def test_raises_type_error(self, transmission_update_time_type_error):
@mark.transmission
class TestTransmissionSnapshot:


@mark.genuine
@given(solved_transmission = solved_transmissions(),
target_time_fraction = floats(min_value = 1e-10, max_value = 1 - 1e-10, allow_nan = False, allow_infinity = False),
Expand Down Expand Up @@ -118,31 +121,42 @@ def test_method(self, solved_transmission, target_time_fraction, angular_positio
stress_unit = stress_unit, current_unit = current_unit,
print_data = print_data)

columns = [f'angular position ({angular_position_unit})', f'angular speed ({angular_speed_unit})',
f'angular acceleration ({angular_acceleration_unit})', f'torque ({torque_unit})',
f'driving torque ({driving_torque_unit})', f'load torque ({load_torque_unit})',
f'tangential force ({force_unit})', f'bending stress ({stress_unit})',
f'contact stress ({stress_unit})']
if solved_transmission.chain[0].electrical_current_is_computable:
columns.append(f'electrical current ({current_unit})')

assert isinstance(data, pd.DataFrame)
assert [element.name for element in solved_transmission.chain] == data.index.to_list()
assert [f'angular position ({angular_position_unit})', f'angular speed ({angular_speed_unit})',
f'angular acceleration ({angular_acceleration_unit})', f'torque ({torque_unit})',
f'driving torque ({driving_torque_unit})', f'load torque ({load_torque_unit})',
f'tangential force ({force_unit})', f'bending stress ({stress_unit})',
f'contact stress ({stress_unit})', f'electrical current ({current_unit})'] == data.columns.to_list()
assert columns == data.columns.to_list()


@mark.error
def test_raises_type_error(self, transmission_snapshot_type_error):
with raises(TypeError):
if transmission_snapshot_type_error:
if transmission_snapshot_type_error:
with raises(TypeError):
basic_transmission.snapshot(**transmission_snapshot_type_error)
else:
basic_transmission.update_time(Time(1, 'sec'))
basic_transmission.time[0] = 1
basic_transmission.snapshot(target_time = Time(1, 'sec'))
else:
transmission_copy = deepcopy(basic_transmission)
transmission_copy.update_time(Time(1, 'sec'))
transmission_copy.time[0] = 1
with raises(TypeError):
transmission_copy.snapshot(target_time = Time(1, 'sec'))


@mark.error
def test_raises_value_error(self):
with raises(ValueError):
basic_transmission.time.clear()
basic_transmission.snapshot(target_time = Time(1, 'sec'))
def test_raises_value_error(self, transmission_snapshot_value_error):
if transmission_snapshot_value_error:
with raises(ValueError):
basic_transmission.snapshot(**transmission_snapshot_value_error)
else:
transmission_copy = deepcopy(basic_transmission)
transmission_copy.time.clear()
with raises(ValueError):
transmission_copy.snapshot(target_time = Time(1, 'sec'))


@mark.transmission
Expand Down

0 comments on commit 821e714

Please sign in to comment.