Index: trunk/SDToolBox/extract_data_utils.py =================================================================== diff -u --- trunk/SDToolBox/extract_data_utils.py (revision 0) +++ trunk/SDToolBox/extract_data_utils.py (revision 207) @@ -0,0 +1,95 @@ +import os +import logging + +from typing import List, Dict, Tuple, Any +from pathlib import Path +from abc import ABC, abstractmethod +import itertools + +from datetime import datetime, timedelta + +from netCDF4 import Dataset +from sklearn.neighbors import BallTree as BallTree +import numpy as np + + +def get_unique_indices(idx_array: Tuple[list, list]) -> Tuple[Tuple[list, list], list]: + """Given a list of indices it gets the unique list of them. + + Args: + idx_array (Tuple[list, list]): Array containing two lists of indices. + + Returns: + Tuple[Tuple[list, list], list]: Unique list of indices and a mapping of the original list with its corresponding unique index. + """ + # Get unique indices. + if not isinstance(idx_array[0], (list, np.ndarray)): + idx_array = [idx_array[0]], [idx_array[1]] + combined_array = np.stack(idx_array, axis=1) + _, unique_idx = np.unique(combined_array, return_index=True, axis=0) + unique_combinations = combined_array[sorted(unique_idx)] + new_unique_idx = ( + unique_combinations[:, 0], + unique_combinations[:, 1], + ) + return ( + new_unique_idx, + [unique_combinations.tolist().index(nn.tolist()) for nn in combined_array], + ) + + +def get_single_nearest_neighbor(value, data_array) -> Tuple[int, int]: + """ + search for nearest decimal degree in an array of decimal + degrees and return the index. + np.argmin returns the indices of minium value along an axis. + so subtract value from all values in data_array, + take absolute value and find index of minium. + """ + logging.info("Searching nearest neighbor for nearest decimal degree.") + index_found = (np.abs(data_array - value)).argmin() + value_found = data_array[index_found] + return index_found, value_found + + +def get_list_nearest_neighbor( + input_values: List[float], reference_list: List[float], +) -> Tuple[List[int], List[float]]: + """Sets the nearest neighbor for all the elements + given in the points_list. + + Arguments: + input_values {List[float]} -- List of elements to correct. + reference_list {List[float]} -- Available neighbor list. + + Returns: + Tuple[List[int], List[float]] -- Tuple with Indices of the nearest neighbors positions and said values. + """ + output_idx = [] + + corrected_values: List[float] = [] + for point in input_values: + idx, value = get_single_nearest_neighbor(point, reference_list) + corrected_values.append(value) + output_idx.append(idx) + return output_idx, corrected_values + + +def get_matrix_nearest_neighbor( + values: np.array, data_array: np.array +) -> Tuple[np.array, np.array]: + """ + search for nearest decimal degree in an array of decimal + degrees and return the index. + np.argmin returns the indices of minium value along an axis. + so subtract value from all values in data_array, + take absolute value and find index of minium. + """ + logging.info("Searching nearest neighbor applying BallTree.") + if data_array.size == 0: + logging.error("No valid input array to find nearest neighbors.") + kd = BallTree(data_array, leaf_size=10) + # k=2 nearest neighbors where k1 = identity + _, index_found = kd.query(values, k=1) + value_found = data_array[index_found.squeeze(), :] + return index_found, value_found Index: trunk/SDToolBox/extract_data.py =================================================================== diff -u -r206 -r207 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 206) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 207) @@ -15,7 +15,12 @@ from SDToolBox import output_messages as om from SDToolBox.input_data import InputData from SDToolBox.output_data import OutputData - +from SDToolBox.extract_data_utils import ( + get_unique_indices, + get_single_nearest_neighbor, + get_list_nearest_neighbor, + get_matrix_nearest_neighbor, +) from datetime import datetime, timedelta from netCDF4 import Dataset @@ -143,6 +148,85 @@ mask[mask == -1] = 0 return mask, LON, LAT + @staticmethod + def get_unmasked_nn_input_data( + input_data: InputData, lat_data: np.array, lon_data: np.array, cases_dict: dict, + ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: + nn_lat_idx, nn_lat_values = get_list_nearest_neighbor( + input_values=input_data._input_lat, reference_list=lat_data + ) + nn_lon_idx, nn_lon_values = get_list_nearest_neighbor( + input_values=input_data._input_lon, reference_list=lon_data + ) + return ( + (nn_lat_idx, nn_lon_idx), + (np.array(nn_lat_values), np.array(nn_lon_values)), + ) + + @staticmethod + def get_masked_nn_input_data( + input_data: InputData, ref_dataset: Dataset, mask_dir: Path, + ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: + """Returns the masked latitude and longiuted of a givendataset and sets + the resulting mask in the input_data structure. + + Args: + input_data (InputData): Input data with coordinates. + ref_dataset (Dataset): Dataset that needs to be masked. + mask_dir (Path): Direction to the mask directory. + + Returns: + Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: + Index of nearest neighbors and a Tuple of corrected Latitude and Longitude + """ + mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" + extracted_lat_pos = 1 + extracted_lon_pos = 0 + sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) + possea = np.where(sea_mask.ravel() == 1)[0] + + if not input_data.is_gridded: + index, vals = get_matrix_nearest_neighbor( + np.vstack((input_data._input_lon, input_data._input_lat)).T, + np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, + ) + index_abs = possea[index].squeeze() + + if index_abs.size <= 0: + return None, None + indexes_lat_lon = np.unravel_index(index_abs, LON.shape) + unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + + else: + index, vals = get_matrix_nearest_neighbor( + np.vstack((input_data._input_lon, input_data._input_lat)).T, + np.vstack((LON.ravel(), LAT.ravel())).T, + ) + indexes_lat_lon = np.unravel_index(index.squeeze(), LON.shape) + unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + + if index.shape[extracted_lon_pos] > 1: + input_data.sea_mask = sea_mask[ + np.ix_(np.sort(indexes_lat_lon[0]), np.sort(indexes_lat_lon[1]),) + ] + else: + input_data.sea_mask = np.array( + sea_mask[indexes_lat_lon[0], indexes_lat_lon[1],] + ) + logging.info(f"Added sea mask to input_data: {input_data.sea_mask}") + if len(unique_values.shape) == 1: + return ( + indexes_lat_lon, + ( + np.array([unique_values[extracted_lat_pos]]), + np.array([unique_values[extracted_lon_pos]]), + ), + ) + return ( + indexes_lat_lon, + (unique_values[:, extracted_lat_pos], unique_values[:, extracted_lon_pos],), + ) + class BaseExtractor(ABC): file_var_key = "variable" file_key_key = "key" @@ -184,7 +268,7 @@ raise Exception(om.error_needs_to_be_in_subclass) @property - def __first_filepath(self) -> str: + def __first_filepath(self) -> Path: return self.__file_iterator[0][self.file_fpath_key] @property @@ -255,29 +339,17 @@ # Set the nearest neighbors and the time reference. logging.info("Getting nearest neighbors.") - nn_lat_lon_idx, nn_values = self.__get_nearest_neighbors_lat_lon( + nn_lat_lon_idx = self.__get_nn_lat_lon_idx( ref_file_path=self.__first_filepath, input_data=input_data, - cases_dict=output_data.data_dict, + output_data=output_data, ) logging.info( - f"Nearest neighbors: \nLat:{nn_values[self.lat_array_pos].astype(str)} \nLon:{nn_values[self.lon_array_pos].astype(str)}" + f"""Nearest neighbors found: + Lat: {np.array(output_data.data_dict[output_data.var_lat_key]).astype(str)} + Lon: {np.array(output_data.data_dict[output_data.var_lon_key]).astype(str)} + Stations (for non-gridded): {np.array(output_data.data_dict[output_data.var_station_idx_key]).astype(str)}""" ) - - # Remove duplicates and store lat/lon values and their stations. - nn_lat_lon_idx, idx_map = self.__get_unique_indices(nn_lat_lon_idx) - output_data.data_dict[output_data.var_lat_key] = [ - value - for idx, value in enumerate(nn_values[self.lat_array_pos].tolist()) - if idx in idx_map - ] - output_data.data_dict[output_data.var_lon_key] = [ - value - for idx, value in enumerate(nn_values[self.lon_array_pos].tolist()) - if idx in idx_map - ] - output_data.data_dict[output_data.var_station_idx_key] = idx_map - logging.info("Getting time references.") time_entry_refs = self.__get_time_ref_group() @@ -322,27 +394,6 @@ return output_data - def __get_unique_indices( - self, nn_lat_lon_idx: Tuple[list, list] - ) -> Tuple[Tuple[list, list], list]: - # Get unique indices. - if not isinstance(nn_lat_lon_idx[0], (list, np.ndarray)): - nn_lat_lon_idx = [nn_lat_lon_idx[0]], [nn_lat_lon_idx[1]] - combined_array = np.stack(nn_lat_lon_idx, axis=1) - _, unique_idx = np.unique(combined_array, return_index=True, axis=0) - unique_combinations = combined_array[sorted(unique_idx)] - new_nn_idx = ( - unique_combinations[:, 0], - unique_combinations[:, 1], - ) - return ( - new_nn_idx, - [ - unique_combinations.tolist().index(nn.tolist()) - for nn in combined_array - ], - ) - def __get_time_ref_group(self) -> List[str]: return [ file_entry @@ -395,39 +446,90 @@ axis=int(not input_data.is_gridded), ) - def __get_nearest_neighbors_lat_lon( - self, ref_file_path: str, input_data: InputData, cases_dict: dict - ) -> Tuple[Tuple[List[int], List[int]], Tuple[np.array, np.array]]: + def __get_nn_lat_lon_idx( + self, ref_file_path: Path, input_data: InputData, output_data: OutputData + ) -> Tuple[List[int], List[int]]: """Gets the corrected index and value for the given input coordinates. Arguments: - directory_path {str} -- Parent directory. + directory_path {Path} -- Parent directory. input_data {InputData} -- Input data. - cases_dict {Dict[str,str]} - -- Dictionary with all values that need format. + output_data {OutputData} -- OutputData Returns: - Tuple[Tuple[List[int], List[int]], Tuple[np.array, np.array] + Tuple[List[int], List[int]] -- Indices of nearest neighbors. """ # Extract index and value for all input lat, lon. maskneeded = self.mask_is_needed(input_data.input_variables) logging.info(f"Getting nearest neighbors with mask = {maskneeded}.") + nn_lat_lon_idx, nn_values = None, None + with Dataset(ref_file_path, "r", self.netcdf_format) as ref_dataset: if not maskneeded: - return self.get_unmasked_nn_lat_lon( + nn_lat_lon_idx, nn_values = ExtractData.get_unmasked_nn_input_data( input_data=input_data, - ref_dataset=ref_dataset, - cases_dict=cases_dict, + lat_data=ref_dataset.variables[self.lat_key][:], + lon_data=ref_dataset.variables[self.lon_key][:], + cases_dict=output_data.data_dict, ) else: - return self.get_masked_nn_lat_lon( + nn_lat_lon_idx, nn_values = ExtractData.get_masked_nn_input_data( input_data=input_data, ref_dataset=ref_dataset, mask_dir=ref_file_path.parent.parent, ) + return self.__post_process_nn_data( + nn_lat_lon_idx=nn_lat_lon_idx, + nn_values=nn_values, + output_data=output_data, + ) + + def __post_process_nn_data( + self, + nn_lat_lon_idx: Tuple[np.array, np.array], + nn_values: Tuple[np.array, np.array], + output_data: OutputData, + ) -> Tuple[np.array, np.array]: + """Corrects the generated NN data removing duplicates and + setting the end results in a dictionary using the lat/lon keys. + + Args: + nn_lat_lon_idx (Tuple[np.array, np.array]): Index with generated NN lats and lons. + nn_values (Tuple[np.array, np.array]): Nn values. + output_data (OutputData): OutputData object where to save the lat/lon and stations indices. + + Returns: + Tuple[np.array, np.array]: Returns the corrected index (LAT, LON) of nearest neighbors. + """ + # Remove duplicates and store lat/lon values and their stations. + logging.info( + f"""Correcting nn indices to remove duplicates: + Lat: {nn_lat_lon_idx[0].astype(str)} + Lon: {nn_lat_lon_idx[1].astype(str)}""" + ) + corrected_nn_lat_lon_idx, idx_map = get_unique_indices(nn_lat_lon_idx) + output_data.data_dict[output_data.var_lat_key] = [ + value + for idx, value in enumerate(nn_values[self.lat_array_pos].tolist()) + if idx in idx_map + ] + output_data.data_dict[output_data.var_lon_key] = [ + value + for idx, value in enumerate(nn_values[self.lon_array_pos].tolist()) + if idx in idx_map + ] + output_data.data_dict[output_data.var_station_idx_key] = idx_map + logging.info( + f"""Corrected nn indices: + Lat: {corrected_nn_lat_lon_idx[0].astype(str)} + Lon: {corrected_nn_lat_lon_idx[1].astype(str)}""" + ) + + return corrected_nn_lat_lon_idx + def _get_filtered_dict( self, values_selected: List[str], values_dict: Dict[str, str] ) -> Dict[str, str]: @@ -443,152 +545,7 @@ logging.info("Filtering variable dictionary.") return {k: v for k, v in values_dict.items() if k in values_selected} - def get_unmasked_nn_lat_lon( - self, input_data: InputData, ref_dataset: Dataset, cases_dict: dict - ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: - nn_lat_idx, nn_lat_values = self._set_nn( - input_values=input_data._input_lat, - reference_list=ref_dataset.variables[self.lat_key][:], - is_gridded=input_data.is_gridded, - ) - nn_lon_idx, nn_lon_values = self._set_nn( - input_values=input_data._input_lon, - reference_list=ref_dataset.variables[self.lon_key][:], - is_gridded=input_data.is_gridded, - ) - return ( - (nn_lat_idx, nn_lon_idx), - (np.array(nn_lat_values), np.array(nn_lon_values)), - ) - - def _set_nn( - self, - input_values: List[float], - reference_list: List[float], - is_gridded: bool, - ) -> Tuple[List[int], List[float]]: - """Sets the nearest neighbor for all the elements - given in the points_list. - - Arguments: - input_values {List[float]} -- List of elements to correct. - reference_list {List[float]} -- Available neighbor list. - is_gridded {bool} -- Needs to extract nn as gridded dataset. - - Returns: - Tuple[List[int], List[float]] -- Tuple with Indices of the nearest neighbors positions and said values. - """ - output_idx = [] - - corrected_values: List[float] = [] - for point in input_values: - idx, value = self.get_nearest_neighbor(point, reference_list) - corrected_values.append(value) - output_idx.append(idx) - return output_idx, corrected_values - @staticmethod - def get_nearest_neighbor(value, data_array) -> Tuple[int, int]: - """ - search for nearest decimal degree in an array of decimal - degrees and return the index. - np.argmin returns the indices of minium value along an axis. - so subtract value from all values in data_array, - take absolute value and find index of minium. - """ - logging.info("Searching nearest neighbor for nearest decimal degree.") - index_found = (np.abs(data_array - value)).argmin() - value_found = data_array[index_found] - return index_found, value_found - - @staticmethod - def get_nearest_neighbor_extended_lon_lat( - values: np.array, data_array: np.array - ) -> Tuple[np.array, np.array]: - """ - search for nearest decimal degree in an array of decimal - degrees and return the index. - np.argmin returns the indices of minium value along an axis. - so subtract value from all values in data_array, - take absolute value and find index of minium. - """ - logging.info("Searching nearest neighbor applying BallTree.") - if data_array.size == 0: - logging.error("No valid input array to find nearest neighbors.") - kd = BallTree(data_array, leaf_size=10) - # k=2 nearest neighbors where k1 = identity - _, index_found = kd.query(values, k=1) - value_found = data_array[index_found.squeeze(), :] - return index_found, value_found - - def get_masked_nn_lat_lon( - self, input_data: InputData, ref_dataset: Dataset, mask_dir: Path, - ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: - """Returns the masked latitude and longiuted of a givendataset. - - Args: - input_data (InputData): Input data with coordinates. - ref_dataset (Dataset): Dataset that needs to be masked. - mask_dir (Path): Direction to the mask directory. - - Returns: - Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: - Index of nearest neighbors and a Tuple of corrected Latitude and Longitude - """ - mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" - extracted_lat_pos = 1 - extracted_lon_pos = 0 - sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) - possea = np.where(sea_mask.ravel() == 1)[0] - - if not input_data.is_gridded: - index, vals = self.get_nearest_neighbor_extended_lon_lat( - np.vstack((input_data._input_lon, input_data._input_lat)).T, - np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, - ) - index_abs = possea[index].squeeze() - - if index_abs.size <= 0: - return None, None - indexes_lat_lon = np.unravel_index(index_abs, LON.shape) - unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) - - else: - index, vals = self.get_nearest_neighbor_extended_lon_lat( - np.vstack((input_data._input_lon, input_data._input_lat)).T, - np.vstack((LON.ravel(), LAT.ravel())).T, - ) - indexes_lat_lon = np.unravel_index(index.squeeze(), LON.shape) - unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) - - if index.shape[extracted_lon_pos] > 1: - input_data.sea_mask = sea_mask[ - np.ix_( - np.sort(indexes_lat_lon[0]), np.sort(indexes_lat_lon[1]), - ) - ] - else: - input_data.sea_mask = np.array( - sea_mask[indexes_lat_lon[0], indexes_lat_lon[1],] - ) - logging.info(f"Added sea mask to input_data: {input_data.sea_mask}") - if len(unique_values.shape) == 1: - return ( - indexes_lat_lon, - ( - np.array([unique_values[extracted_lat_pos]]), - np.array([unique_values[extracted_lon_pos]]), - ), - ) - return ( - indexes_lat_lon, - ( - unique_values[:, extracted_lat_pos], - unique_values[:, extracted_lon_pos], - ), - ) - - @staticmethod def __get_case_subset( dataset: Dataset, variable_name: str,