"""Tests for ``SDToolBox.predictor_definition.PredictorDefinition``."""
import os
import random
import string
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from tests.TestUtils import TestUtils
from tests.TestHelper import TestHelper as th
from tests.ChunkData import ChunkDataUtils
from SDToolBox.input_data import InputData
from SDToolBox.predictor_definition import PredictorDefinition
from SDToolBox.output_data import OutputData
from SDToolBox.extract_data import ExtractData
import SDToolBox.output_messages as om


def _open_generated_swh_dataset() -> xr.Dataset:
    """Open the pre-generated ``test_swh_1981_to_1986_CF.nc`` test file.

    Returns:
        xr.Dataset -- Dataset read from the ``generated_output_data``
        local test data directory.

    Raises:
        AssertionError -- When the test file does not exist.
    """
    dir_test_data = TestUtils.get_local_test_data_dir("generated_output_data")
    test_nc_file = dir_test_data / "test_swh_1981_to_1986_CF.nc"
    assert test_nc_file.is_file(), "Test file not found at {}".format(
        test_nc_file
    )
    return xr.open_dataset(test_nc_file)


@pytest.fixture(scope="module")
def valid_gridded_netcdf_xarray() -> xr.Dataset:
    """Extract ERA5 'swh' data, write it to netCDF and open it with xarray.

    The fixture is module scoped, so the extraction and the netCDF
    generation run once for all tests of this module that request it.

    Returns:
        xr.Dataset -- Dataset opened from the generated netCDF file.
    """
    # 1. Given
    basic_inputdata = InputData()
    basic_inputdata.input_variables = ["swh"]
    basic_inputdata.input_years = [1981, 1982]

    dir_test_data = Path(TestUtils.get_local_test_data_dir("chunked_data"))
    output_test_data = TestUtils.get_output_test_data_dir("output_data")

    default_chunked_lon_lat = ChunkDataUtils.get_default_chunked_lon_lat()
    basic_inputdata.input_coordinates = {
        "LON": default_chunked_lon_lat[0][1:2],
        "LAT": default_chunked_lon_lat[1][1:2],
    }
    # NOTE(review): the fixture name says "gridded" but the input is flagged
    # as non gridded and the file is named "test_non_gridded" -- confirm
    # which of the two is intended.
    basic_inputdata.is_gridded = False

    # 2. When
    output_data = ExtractData.get_era_5(
        directory_path=dir_test_data, input_data=basic_inputdata
    )
    ChunkDataUtils.assert_valid_ndarray(output_data.variables["swh"])
    netcdf_filepath = output_data.generate_netcdf(
        dir_path=output_test_data,
        base_name="test_non_gridded",
        dataset_code=None,
    )

    # 3. Then
    return xr.open_dataset(netcdf_filepath)


class Test_SpatialGradientsCalculation:
    """Tests for ``PredictorDefinition.compute_spatial_gradients``."""

    @pytest.mark.systemtest
    def test_given_chunked_data_compute_slp_gradients(self):
        # 1. Get test data
        output_xarray = _open_generated_swh_dataset()

        # 2. Run test.
        result = PredictorDefinition.compute_spatial_gradients(output_xarray)

        # 3. Validate extracted data.
        assert result is not None
        assert isinstance(result, xr.DataArray)
        assert result.dims == output_xarray.dims

    @pytest.mark.unittest
    def test_given_an_xarray_compute_slp_gradients(self):
        # 1. Given
        lat = [43.125, 50, 60.125, 13.5]
        lon = [43.125, 50, 60.125, 13.5]
        times = pd.date_range("2000-01-01", periods=4)
        data = np.random.rand(len(times), len(lat), len(lon))
        data_array = xr.DataArray(
            data, coords=[times, lat, lon], dims=["time", "lat", "lon"]
        )

        # 2. When
        result = PredictorDefinition.compute_spatial_gradients(data_array)

        # 3. Then
        assert result is not None
        assert isinstance(result, xr.DataArray)
        assert result.dims == data_array.dims


class Test_ComputePCA:
    """Tests for ``PredictorDefinition.compute_PCA``."""

    @pytest.mark.unittest
    def test_given_input_data_compute_PCA(self):
        # 1. Given
        grad2slp = np.random.rand(10, 3, 4)
        slp = np.random.rand(10, 3, 4)
        array_dataset = th.get_default_xarray_dataset()
        array_dataset.attrs["msl_p"] = slp
        array_dataset.attrs["grad2slp"] = grad2slp
        data_processing = PredictorDefinition()

        # 2. When
        mean, deviation, variance = data_processing.compute_PCA(array_dataset)

        # 3. Then
        assert mean is not None
        assert deviation is not None
        assert variance is not None


class Test_AtmosphericAggregated:
    """Tests for ``PredictorDefinition.atmospheric_aggregated``."""

    @pytest.mark.unittest
    def test_when_no_dataset_raises(self):
        # 1. Given
        expected_error = om.error_all_arguments_required
        result = None

        # 2. When
        with pytest.raises(Exception) as e_info:
            result = PredictorDefinition.atmospheric_aggregated(
                dataset=None, time_scale=42
            )

        # 3. Then
        assert result is None
        assert str(e_info.value) == expected_error

    @pytest.mark.unittest
    def test_when_no_time_scale_raises(self):
        # 1. Given
        expected_error = om.error_all_arguments_required
        result = None

        # 2. When
        with pytest.raises(Exception) as e_info:
            result = PredictorDefinition.atmospheric_aggregated(
                dataset=th.get_default_xarray_dataset(), time_scale=None
            )

        # 3. Then
        assert result is None
        assert str(e_info.value) == expected_error

    @pytest.mark.unittest
    def test_when_all_args_given_does_not_raise(self):
        # 1. Given
        dataset = th.get_gridded_xarray_from_file()
        time_scale = 2
        input_vars = ["SWH"]
        result_ds = None
        result_dict = None

        # 2. When
        try:
            result_ds, result_dict = PredictorDefinition.atmospheric_aggregated(
                dataset=dataset, time_scale=time_scale, var_list=input_vars
            )
        except Exception as e_info:
            pytest.fail(
                "Exception thrown but not expected,{}".format(str(e_info))
            )

        # 3. Then
        assert result_ds is not None
        assert result_dict is not None
        assert "SWH" in result_dict
        assert len(result_dict["SWH"]) == time_scale

    @pytest.mark.integrationtest
    @pytest.mark.parametrize("var_list", [None, [], ["dummy"]])
    def test_when_no_valid_var_list_then_raises(
        self, var_list: Optional[List[str]]
    ):
        # 1. Given
        dataset = th.get_gridded_xarray_from_file()
        time_scale = 42
        output_result = None
        available_vars = list(dataset.variables.mapping.keys())
        expected_message = om.error_pd_input_list_not_valid.format(
            available_vars
        )

        # 2. When
        with pytest.raises(Exception) as e_info:
            output_result = PredictorDefinition.atmospheric_aggregated(
                dataset=dataset, time_scale=time_scale, var_list=var_list,
            )

        # 3. Then
        # Kept outside the 'raises' block: inside it these assertions would
        # never run because the call above is expected to raise.
        assert output_result is None
        assert str(e_info.value) == expected_message

    @pytest.mark.systemtest
    def test_given_gridded_valid_data_output_does_not_raise(
        self, valid_gridded_netcdf_xarray: xr.Dataset
    ):
        # 1. When
        sliced_result = PredictorDefinition.atmospheric_aggregated(
            dataset=valid_gridded_netcdf_xarray, time_scale=1, var_list=["SWH"]
        )

        # 2. Then
        output_ds, variable_predictors = sliced_result
        assert isinstance(output_ds, xr.Dataset)
        assert isinstance(variable_predictors, dict)
        assert isinstance(variable_predictors["SWH"], list)
        assert len(variable_predictors["SWH"]) == 1
        assert isinstance(variable_predictors["SWH"][0], xr.DataArray)


class Test_AtmosphericAveraged:
    """Tests for the ``PredictorDefinition.atmospheric_averaged_*`` methods."""

    @staticmethod
    def get_output_chunked_gridded_file() -> xr.Dataset:
        """Opens the pre-generated test netCDF file as an xarray dataset.

        Returns:
            xarray.Dataset -- Dataset as returned by ``xr.open_dataset``.
        """
        return _open_generated_swh_dataset()

    @pytest.mark.unittest
    def test_when_mean_given_no_args_then_raises(self):
        # 1. Given
        dataset = None
        t_aggr = None
        result = None
        exp_mssg = om.error_all_arguments_required

        # 2. When
        with pytest.raises(Exception) as e_info:
            result = PredictorDefinition.atmospheric_averaged_mean(
                dataset=dataset, t_days=t_aggr
            )

        # 3. Then
        assert result is None
        assert str(e_info.value) == exp_mssg

    @pytest.mark.unittest
    def test_when_mean_given_args_then_does_not_raise(self):
        # 1. Given
        dataset = self.get_output_chunked_gridded_file()
        t_aggr = 3

        # 2. When
        averaged_mean = PredictorDefinition.atmospheric_averaged_mean(
            dataset=dataset, t_days=t_aggr
        )

        # 3. Then
        assert averaged_mean, "No averaged mean was calculated."

    @pytest.mark.unittest
    def test_when_median_given_args_then_does_not_raise(self):
        # 1. Given
        dataset = th.get_gridded_xarray_from_file()
        t_aggr = 3

        # 2. When
        result = PredictorDefinition.atmospheric_averaged_median(
            dataset=dataset, t_days=t_aggr
        )

        # 3. Then
        assert result is not None

    @pytest.mark.unittest
    def test_when_max_given_args_then_does_not_raise(self):
        # 1. Given
        dataset = th.get_gridded_xarray_from_file()
        t_aggr = 3

        # 2. When
        result = PredictorDefinition.atmospheric_averaged_max(
            dataset=dataset, t_days=t_aggr
        )

        # 3. Then
        assert result is not None

    @pytest.mark.systemtest
    def test_given_gridded_valid_data_output_does_not_raise(
        self, valid_gridded_netcdf_xarray: xr.Dataset
    ):
        # 1. When
        sliced_result = PredictorDefinition.atmospheric_averaged_max(
            dataset=valid_gridded_netcdf_xarray, t_days=1
        )

        # 2. Then
        assert isinstance(sliced_result, xr.Dataset)


class Test_GetDatasetTimeSlice:
    """Tests for ``PredictorDefinition.get_dataset_time_slice``."""

    @pytest.mark.unittest
    def test_given_no_args_then_raises(self):
        # 1. Given
        dataset = None
        t_aggr = None
        result = None
        exp_mssg = om.error_all_arguments_required

        # 2. When
        with pytest.raises(Exception) as e_info:
            result = PredictorDefinition.get_dataset_time_slice(
                dataset=dataset, t_days=t_aggr
            )

        # 3. Then
        assert result is None
        assert str(e_info.value) == exp_mssg

    @pytest.mark.unittest
    def test_given_no_gridded_dataset_then_raises(self):
        # 1. Given
        dataset = th.get_default_xarray_dataset()
        t_aggr = 4
        result = None
        exp_mssg = om.error_only_for_gridded_dataset

        # 2. When
        with pytest.raises(Exception) as e_info:
            result = PredictorDefinition.get_dataset_time_slice(
                dataset=dataset, t_days=t_aggr
            )

        # 3. Then
        assert result is None
        assert str(e_info.value) == exp_mssg

    @pytest.mark.unittest
    def test_given_xarray_then_gets_subset(self):
        # 1. Given
        x_dataset = th.get_gridded_xarray_from_file()

        # 2. When
        sliced_result = PredictorDefinition.get_dataset_time_slice(
            dataset=x_dataset, t_days=1
        )

        # 3. Then
        assert sliced_result is not None

    @pytest.mark.systemtest
    def test_given_gridded_valid_data_output_does_not_raise(
        self, valid_gridded_netcdf_xarray: xr.Dataset
    ):
        # 1. When
        sliced_result = PredictorDefinition.get_dataset_time_slice(
            dataset=valid_gridded_netcdf_xarray, t_days=1
        )

        # 2. Then
        assert isinstance(sliced_result, xr.Dataset)