Index: trunk/SDToolBox/extract_data.py =================================================================== diff -u -r76 -r91 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 76) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 91) @@ -18,6 +18,7 @@ from datetime import datetime, timedelta from netCDF4 import Dataset +from sklearn.neighbors import BallTree as BallTree import numpy as np # endregion @@ -124,6 +125,10 @@ def var_dict(self) -> Dict[str, str]: raise Exception(om.error_needs_to_be_in_subclass) + @abstractmethod + def mask_is_needed(self, input_variables: List[str]) -> bool: + raise Exception(om.error_needs_to_be_in_subclass) + @property def __first_filepath(self) -> str: return self.__file_iterator[0][self.file_fpath_key] @@ -157,6 +162,15 @@ def get_new_file_iter(self, file_entry, file_path) -> Dict[str, str]: raise Exception(om.error_needs_to_be_in_subclass) + @abstractmethod + def get_masked_nn_lat_lon( + self, + input_data: InputData, + ref_dataset: Dataset, + cases_dict: dict, + mask_dir: str) -> Tuple[List[int], List[int]]: + raise Exception(om.error_needs_to_be_in_subclass) + def set_file_iterator( self, input_data: InputData, @@ -323,21 +337,20 @@ -- Indices of nearest neighbors. """ # Extract index and value for all input lat, lon. 
+ maskneeded = self.mask_is_needed(input_data.input_variables) with Dataset(ref_file_path, 'r', self.netcdf_format) \ as ref_dataset: - nn_lat_idx = self.__set_nn( - input_values=input_data._input_lat, - reference_list=ref_dataset.variables[self.lat_key][:], - output_values=cases_dict[OutputData.var_lat_key], - is_gridded=input_data.is_gridded - ) - nn_lon_idx = self.__set_nn( - input_values=input_data._input_lon, - reference_list=ref_dataset.variables[self.lon_key][:], - output_values=cases_dict[OutputData.var_lon_key], - is_gridded=input_data.is_gridded - ) - return nn_lon_idx, nn_lat_idx + if not maskneeded: + return self._get_unmasked_nn_lat_lon( + input_data=input_data, + ref_dataset=ref_dataset, + cases_dict=cases_dict) + else: + return self.get_masked_nn_lat_lon( + input_data=input_data, + ref_dataset=ref_dataset, + cases_dict=cases_dict, + mask_dir=os.path.dirname(os.path.dirname(ref_file_path))) def _get_filtered_dict( self, @@ -357,8 +370,27 @@ for k, v in values_dict.items() if k in values_selected} - def __set_nn( + def _get_unmasked_nn_lat_lon( self, + input_data: InputData, + ref_dataset: Dataset, + cases_dict: dict) -> Tuple[List[int], List[int]]: + nn_lat_idx = self._set_nn( + input_values=input_data._input_lat, + reference_list=ref_dataset.variables[self.lat_key][:], + output_values=cases_dict[OutputData.var_lat_key], + is_gridded=input_data.is_gridded + ) + nn_lon_idx = self._set_nn( + input_values=input_data._input_lon, + reference_list=ref_dataset.variables[self.lon_key][:], + output_values=cases_dict[OutputData.var_lon_key], + is_gridded=input_data.is_gridded + ) + return nn_lat_idx, nn_lon_idx + + def _set_nn( + self, input_values: List[float], reference_list: List[float], output_values: List[float], @@ -410,6 +442,23 @@ return index_found, value_found @staticmethod + def get_nearest_neighbor_extended( + values, + data_array) -> Tuple[int, int]: + """ + search for nearest decimal degree in an array of decimal + degrees and return the index. 
+ np.argmin returns the indices of minimum value along an axis. + so subtract value from all values in data_array, + take absolute value and find index of minimum. + """ + kd = BallTree(data_array, leaf_size=10) + # query() is called with its default k=1: only the nearest neighbor + _, index_found = kd.query(values) + value_found = data_array[index_found.squeeze(), :] + return index_found, value_found + + @staticmethod def __get_case_subset( dataset: Dataset, variable_name: str, @@ -447,11 +496,15 @@ __era5_lon_key = 'longitude' __era5_lat_key = 'latitude' - __era5_var_dict = { + """swh, pp1d, mwd and mwp are WAVES variables. """ + __era5_wave_dict = { 'swh': 'Hs', 'pp1d': 'Tp', 'mwd': 'MWD', 'mwp': 'Tm', + } + + __era5_wind_mslp_dict = { 'msl': 'msl', 'u10': 'wind_u', 'v10': 'wind_v' @@ -467,8 +520,16 @@ @property def var_dict(self) -> Dict[str, str]: - return self.__era5_var_dict + return { + **self.__era5_wave_dict, + **self.__era5_wind_mslp_dict} + def mask_is_needed(self, input_variables: List[str]) -> bool: + mask_variables = self._get_filtered_dict( + input_variables, + self.__era5_wave_dict) + return bool(mask_variables) + def get_filepath( self, dir_path: str, @@ -522,7 +583,7 @@ """ filtered_dict = self._get_filtered_dict( input_data.input_variables, - self.__era5_var_dict) + self.var_dict) return itertools.product( filtered_dict.items(), input_data.input_years) @@ -535,6 +596,46 @@ self.file_fpath_key: file_path } + def get_masked_nn_lat_lon( + self, + input_data: InputData, + ref_dataset: Dataset, + cases_dict: dict, + mask_dir: str) -> Tuple[List[int], List[int]]: + mask_filepath = os.path.join(mask_dir, 'ERA5_landsea_mask.nc') + with Dataset(mask_filepath, 'r', self.netcdf_format) \ + as dsmask: + mask = dsmask['lsm'][0, 0::2, 0::2] + lon = dsmask['longitude'][0::2] + lon[lon > 180] += -360 + lat = dsmask['latitude'][0::2] + loni = np.where(lon == 180)[0][0] + LON, LAT = np.meshgrid(lon, lat) + LON = np.concatenate( + (LON[:, loni+1:], LON[:, 0:loni+1]), + axis=1) + mask = 
np.concatenate( + (mask[:, loni+1:], mask[:, 0:loni+1]), + axis=1) + # positions where we are in the sea + possea = np.where(mask.ravel() == 0.)[0] + index, vals = self.get_nearest_neighbor_extended( + np.vstack( + ( + input_data._input_lon, + input_data._input_lat + )).T, + np.vstack( + ( + LON.ravel()[possea], + LAT.ravel()[possea])).T) + index_abs = possea[index].squeeze() + indexes = np.array( + [ + np.unravel_index(ii, LON.shape) + for ii in index_abs]) + return indexes + class __EarthExtractor(BaseExtractor): __earth_lon_key = 'lon' __earth_lat_key = 'lat' @@ -557,6 +658,9 @@ def var_dict(self) -> Dict[str, str]: return self.__earth_var_dict + def mask_is_needed(self, input_variables: List[str]) -> bool: + return False + def get_case_time_values( self, ymd: str, @@ -637,3 +741,11 @@ self.file_scenario_key: file_entry[3], self.file_fpath_key: file_path } + + def get_masked_nn_lat_lon( + self, + input_data: InputData, + ref_dataset: Dataset, + cases_dict: dict, + mask_dir: str) -> Tuple[List[int], List[int]]: + raise Exception(om.error_function_not_implemented) Index: trunk/SDToolBox/input_data.py =================================================================== diff -u -r90 -r91 --- trunk/SDToolBox/input_data.py (.../input_data.py) (revision 90) +++ trunk/SDToolBox/input_data.py (.../input_data.py) (revision 91) @@ -63,25 +63,11 @@ is_gridded {bool} -- Coordinates represent a grid. 
(default: {False}) """ - if input_coordinates is None or not input_coordinates: - raise Exception( - om.error_not_enough_coordinates) - if not isinstance(input_coordinates, dict): - raise Exception( - om.error_wrong_format_input_coordinates_not_a_dictionary) - if isinstance(input_years, np.ndarray): self.input_years = input_years.tolist() else: self.input_years = input_years - if self.longitude_key not in input_coordinates: - raise Exception( - om.error_wrong_format_input_coordinates_LON_missing) - if self.latitude_key not in input_coordinates: - raise Exception( - om.error_wrong_format_input_coordinates_LAT_missing) - self.input_coordinates = input_coordinates self.input_variables = input_variables self.input_scenarios = input_scenarios @@ -117,6 +103,12 @@ if not self.input_coordinates or \ len(self.input_coordinates) < 1: raise IOError(om.error_not_enough_coordinates) + if self.longitude_key not in self.input_coordinates: + raise Exception( + om.error_wrong_format_input_coordinates_LON_missing) + if self.latitude_key not in self.input_coordinates: + raise Exception( + om.error_wrong_format_input_coordinates_LAT_missing) self.__extract_input_lon_lat() return True Index: trunk/tests/test_extract_data.py =================================================================== diff -u -r75 -r91 --- trunk/tests/test_extract_data.py (.../test_extract_data.py) (revision 75) +++ trunk/tests/test_extract_data.py (.../test_extract_data.py) (revision 91) @@ -7,7 +7,6 @@ from tests.TestUtils import TestUtils from netCDF4 import Dataset -from SDToolBox import main as main from SDToolBox.input_data import InputData from SDToolBox.extract_data import ExtractData import SDToolBox.output_messages as om @@ -57,7 +56,7 @@ input_data = InputData() input_data.input_variables = input_variables - input_data.input_coordinates = [(4.2, 2.4), ] + input_data.input_coordinates = {'LAT': [4.2], 'LON': [2.4]} input_data.input_years = [1981, 1982] # 2. When