Index: trunk/tests/test_extract_waves.py =================================================================== diff -u -r38 -r40 --- trunk/tests/test_extract_waves.py (.../test_extract_waves.py) (revision 38) +++ trunk/tests/test_extract_waves.py (.../test_extract_waves.py) (revision 40) @@ -4,11 +4,20 @@ import netCDF4 from tests.TestUtils import TestUtils from netCDF4 import Dataset + from SDToolBox import main as main -from SDToolBox import extract_data_era5 +from SDToolBox import extract_waves from SDToolBox import data_acquisition +from SDToolBox import outputmessages as om + import numpy as np +test_variable_dict = { + 'swh': 'Hs', + 'pp1d': 'Tp', + 'mwd': 'MWD', + 'mwp': 'Tm', +} class Test_create: @@ -17,10 +26,10 @@ # 1. Given input_data = None output_result = None - expected_error = 'No valid input data.' + expected_error = om.error_no_valid_input_data # 2. When with pytest.raises(IOError) as e_info: - output_result = extract_data_era5.ExtractDataEra5(input_data) + output_result = extract_waves.ExtractWaves(input_data) # 3. Then error_message = str(e_info.value) @@ -29,7 +38,44 @@ 'Expected exception message {},'.format(expected_error) + \ 'retrieved {}'.format(error_message) + @pytest.mark.unittest + def test_given_no_input_coords_then_exception_is_risen(self): + # 1. Given + input_data = data_acquisition.InputData() + input_data.coord_list = [] + data_extractor = None + expected_error = om.error_not_enough_coordinates + # 2. When + with pytest.raises(IOError) as e_info: + data_extractor = extract_waves.ExtractWaves(input_data) + # 3. Then + error_message = str(e_info.value) + assert data_extractor is None + assert error_message == expected_error, '' + \ + 'Expected exception message {},'.format(expected_error) + \ + 'retrieved {}'.format(error_message) + + @pytest.mark.unittest + def test_given_no_dict_of_vars_then_exception_is_risen(self): + # 1. Given + input_data = data_acquisition.InputData() + input_data.coord_list = [(4.2, 42)] + input_data.input_dict = None + data_extractor = None + expected_error = om.error_no_valid_dict_of_vars + # 2. When + with pytest.raises(IOError) as e_info: + data_extractor = extract_waves.ExtractWaves(input_data) + + # 3. Then + error_message = str(e_info.value) + assert data_extractor is None + assert error_message == expected_error, '' + \ + 'Expected exception message {},'.format(expected_error) + \ + 'retrieved {}'.format(error_message) + + class Test_subset_era_5: @pytest.mark.systemtest @@ -40,11 +86,12 @@ # dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global' input_data = data_acquisition.InputData() - input_data.coord_list = [(4.2, 2.4),] + input_data.input_dict = test_variable_dict + input_data.coord_list = [(4.2, 2.4), ] # 2. When - extract_wave = extract_data_era5.ExtractDataEra5(input_data) - output_data = extract_wave.subset_waves(dir_test_data, 1981, 1982) + extract_wave = extract_waves.ExtractWaves(input_data) + output_data = extract_wave.subset_era_5(dir_test_data, 1981, 1982) # 3. Then assert output_data is not None @@ -60,40 +107,42 @@ def test_given_waves_folder_then_subset_collection_is_extracted_with_coord_list(self): # 1. Given # When using local data you can just replace the comment in these lines - dir_test_data = TestUtils.get_local_test_data_dir('netCDF_Waves_data') - # dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global' + # dir_test_data = TestUtils.get_local_test_data_dir('netCDF_Waves_data') + dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global' input_data = data_acquisition.InputData() + input_data.input_dict = test_variable_dict input_data.coord_list = [(4.2, 2.4), (42, 2.4), (42, 24), (4.2, 24)] # 2. When - extract_wave = extract_data_era5.ExtractDataEra5(input_data) - dataset_list = extract_wave.subset_waves(dir_test_data, 1981, 1982) + extract_wave = extract_waves.ExtractWaves(input_data) + dataset_list = extract_wave.subset_era_5(dir_test_data, 1981, 1982) # 3. Then assert dataset_list is not None - + """ - Checks that the longitude is normalized if a value higher than 180 is passed + Checks that the longitude is normalized + if a value higher than 180 is passed """ @pytest.mark.unittest def test_given_longitude_higher_than_180_returns_normalized_longitude(self): - #setup + # setup longitude = 400 - normalized_longitude = longitude -180 + normalized_longitude = longitude - 180 dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global\\Hs' filename = 'era5_Global_Hs_1980.nc' path = dir_test_data + filename with Dataset(path, 'r', format='netCDF4') as case_dataset: input_data = data_acquisition.InputData() input_data.coord_list = [(4.2, 2.4)] - extractwaves = extract_data_era5.ExtractDataEra5(input_data) - - #call + extractwaves = extract_waves.ExtractWaves(input_data) + + # call result = extractwaves.check_for_longitude(longitude) - #assert + # assert assert result == normalized_longitude - + """ Checks that the longitude is unchanged if a value lower than 180 is passed """ @@ -105,34 +154,36 @@ path = dir_test_data + filename with Dataset(path, 'r', format='netCDF4') as case_dataset: input_data = data_acquisition.InputData() + input_data.input_dict = test_variable_dict input_data.coord_list = [(4.2, 2.4)] - extractwaves = extract_data_era5.ExtractDataEra5(input_data) - - #call + extractwaves = extract_waves.ExtractWaves(input_data) + + # call result = extractwaves.check_for_longitude(longitude) - #assert + # assert assert result == 30 """ Checks that array of years if correctly generated """ @pytest.mark.unittest def test_years_array_is_correctly_generated(self): - #setup + # setup year1 = 1980 yearN = 1983 - result_array = [1980,1981,1982,1983] + result_array = [1980, 1981, 1982, 1983] result = [] dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global\\Hs' filename = 'era5_Global_Hs_1980.nc' path = dir_test_data + filename with Dataset(path, 'r', format='netCDF4') as case_dataset: input_data = data_acquisition.InputData() + input_data.input_dict = test_variable_dict input_data.coord_list = [(4.2, 2.4)] - extractwaves = extract_data_era5.ExtractDataEra5(input_data) - #call + extractwaves = extract_waves.ExtractWaves(input_data) + # call result = extractwaves.generate_years_array(year1, yearN) - #assert + # assert assert result[0] == result_array[0] assert result[1] == result_array[1] assert result[2] == result_array[2]