Index: trunk/SDToolBox/extract_data.py
===================================================================
diff -u -r210 -r212
--- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 210)
+++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 212)
@@ -17,6 +17,7 @@
 from SDToolBox.output_data import OutputData
 from SDToolBox.extract_data_utils import (
     get_unique_indices,
+    get_stacked_matrix,
     get_single_nearest_neighbor,
     get_list_nearest_neighbor,
     get_matrix_nearest_neighbor,
@@ -35,6 +36,11 @@
 
 
 class ExtractData:
+
+    netcdf_format = "netCDF4"
+    lon_array_pos = 1
+    lat_array_pos = 0
+
     @staticmethod
     def get_era_5(directory_path: str, input_data: InputData):
         """Extracts a collection of netCDF4 ERA5 subsets based on the
@@ -150,13 +156,14 @@
 
     @staticmethod
     def get_unmasked_nn_input_data(
-        input_data: InputData, lat_data: np.array, lon_data: np.array, cases_dict: dict,
+        input_data: InputData, lat_values: np.array, lon_values: np.array
     ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]:
+
         nn_lat_idx, nn_lat_values = get_list_nearest_neighbor(
-            input_values=input_data._input_lat, reference_list=lat_data
+            input_values=input_data._input_lat, reference_list=lat_values,
         )
         nn_lon_idx, nn_lon_values = get_list_nearest_neighbor(
-            input_values=input_data._input_lon, reference_list=lon_data
+            input_values=input_data._input_lon, reference_list=lon_values,
         )
         return (
             (nn_lat_idx, nn_lon_idx),
@@ -165,65 +172,73 @@
 
     @staticmethod
     def get_masked_nn_input_data(
-        input_data: InputData, ref_dataset: Dataset, mask_dir: Path,
+        input_data: InputData,
+        lat_values: np.array,
+        lon_values: np.array,
+        mask_values: np.array,
     ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]:
         """Returns the masked latitude and longitude of a given dataset and sets
         the resulting mask in the input_data structure.
 
         Args:
             input_data (InputData): Input data with coordinates.
-            ref_dataset (Dataset): Dataset that needs to be masked.
-            mask_dir (Path): Direction to the mask directory.
+            lat_values (np.array): Array with latitude values.
+            lon_values (np.array): Array with longitude values.
+            mask_values (np.array): Mask array aligned with the lat/lon values above.
 
         Returns:
             Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]:
             Index of nearest neighbors and a Tuple of corrected Latitude and Longitude
         """
-        mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc"
-        extracted_lat_pos = 0
-        extracted_lon_pos = 1
-        sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath)
-        possea = np.where(sea_mask.ravel() == 1)[0]
+        masked_positions = np.where(mask_values.ravel() == 1)[0]
         if not input_data.is_gridded:
             index, vals = get_matrix_nearest_neighbor(
                 np.vstack((input_data._input_lat, input_data._input_lon)).T,
-                np.vstack((LAT.ravel()[possea], LON.ravel()[possea])).T,
+                np.vstack(
+                    (
+                        lat_values.ravel()[masked_positions],
+                        lon_values.ravel()[masked_positions],
+                    )
+                ).T,
             )
-            index_abs = possea[index].squeeze()
+            index_abs = masked_positions[index].squeeze()
             if index_abs.size <= 0:
                 return None, None
             # We don't need to get unique indices here.
-            indexes_lat_lon = np.unravel_index(index_abs, LON.shape)
+            indexes_lat_lon = np.unravel_index(index_abs, lon_values.shape)
         else:
             index, vals = get_matrix_nearest_neighbor(
                 np.vstack((input_data._input_lat, input_data._input_lon)).T,
-                np.vstack((LAT.ravel(), LON.ravel())).T,
+                np.vstack((lat_values.ravel(), lon_values.ravel())).T,
             )
-            indexes_lat_lon = np.unravel_index(index.squeeze(), LON.shape)
+            indexes_lat_lon = np.unravel_index(index.squeeze(), lon_values.shape)
             if index.shape[0] > 1:
-                input_data.sea_mask = sea_mask[
+                input_data.sea_mask = mask_values[
                     np.ix_(indexes_lat_lon[0], indexes_lat_lon[1],)
                 ]
             else:
                 input_data.sea_mask = np.array(
-                    sea_mask[indexes_lat_lon[0], indexes_lat_lon[1],]
+                    mask_values[indexes_lat_lon[0], indexes_lat_lon[1],]
                 )
-        logging.info(f"Added sea mask to input_data: {input_data.sea_mask}")
+        logging.info(f"Added sea mask to input_data: {input_data.sea_mask}.")
         if len(indexes_lat_lon) == 1:
             return (
                 indexes_lat_lon,
                 (
-                    np.array([vals[extracted_lat_pos]]),
-                    np.array([vals[extracted_lon_pos]]),
+                    np.array([vals[ExtractData.lat_array_pos]]),
+                    np.array([vals[ExtractData.lon_array_pos]]),
                 ),
             )
         return (
             indexes_lat_lon,
-            (vals[:, extracted_lat_pos].data, vals[:, extracted_lon_pos].data,),
+            (
+                vals[:, ExtractData.lat_array_pos].data,
+                vals[:, ExtractData.lon_array_pos].data,
+            ),
         )
 
 
 class BaseExtractor(ABC):
@@ -233,10 +248,7 @@
     file_year_key = "year"
     file_month_key = "month"
     file_scenario_key = "scenario"
-    netcdf_format = "netCDF4"
     stations_idx_key = "stations_idx"
-    lon_array_pos = 1
-    lat_array_pos = 0
 
     __file_iterator: list = []
 
@@ -339,9 +351,7 @@
         # Set the nearest neighbors and the time reference.
         logging.info("Getting nearest neighbors.")
         nn_lat_lon_idx = self.__get_nn_lat_lon_idx(
-            ref_file_path=self.__first_filepath,
-            input_data=input_data,
-            output_data=output_data,
+            input_data=input_data, output_data=output_data,
         )
         logging.info(
             f"""Nearest neighbors found:
@@ -363,7 +373,7 @@
                 + " year {}.".format(file_entry.get(self.file_year_key))
             )
             # Lazy loading of the dataset.
-            with Dataset(filepath, "r", self.netcdf_format) as input_dataset:
+            with Dataset(filepath, "r", ExtractData.netcdf_format) as input_dataset:
                 output_data.set_in_var_dict(
                     var_name=variable_key,
                     value=self.__get_variable_subset(
@@ -446,13 +456,12 @@
             )
 
     def __get_nn_lat_lon_idx(
-        self, ref_file_path: Path, input_data: InputData, output_data: OutputData
+        self, input_data: InputData, output_data: OutputData
     ) -> Tuple[List[int], List[int]]:
         """Gets the corrected index and value for the given input coordinates.
 
         Arguments:
-            directory_path {Path} -- Parent directory.
             input_data {InputData} -- Input data.
             output_data {OutputData} -- OutputData
@@ -465,27 +474,40 @@
         logging.info(f"Getting nearest neighbors with mask = {maskneeded}.")
         nn_lat_lon_idx, nn_values = None, None
-        with Dataset(ref_file_path, "r", self.netcdf_format) as ref_dataset:
-            if not maskneeded:
-                nn_lat_lon_idx, nn_values = ExtractData.get_unmasked_nn_input_data(
-                    input_data=input_data,
-                    lat_data=ref_dataset.variables[self.lat_key][:],
-                    lon_data=ref_dataset.variables[self.lon_key][:],
-                    cases_dict=output_data.data_dict,
-                )
-            else:
-                nn_lat_lon_idx, nn_values = ExtractData.get_masked_nn_input_data(
-                    input_data=input_data,
-                    ref_dataset=ref_dataset,
-                    mask_dir=ref_file_path.parent.parent,
-                )
+        if maskneeded:
+            mask_filepath = (
+                Path(self.__first_filepath.parent.parent) / "ERA5_landsea_mask.nc"
+            )
+            sea_mask, lon_values, lat_values = ExtractData.get_seamask_positions(
+                mask_filepath
+            )
+            nn_lat_lon_idx, nn_values = ExtractData.get_masked_nn_input_data(
+                input_data=input_data,
+                lat_values=lat_values,
+                lon_values=lon_values,
+                mask_values=sea_mask,
+            )
+        else:
+            lat_values, lon_values = self.__get_reference_lat_lon(
+                self.__first_filepath
+            )
+            nn_lat_lon_idx, nn_values = ExtractData.get_unmasked_nn_input_data(
+                input_data=input_data, lat_values=lat_values, lon_values=lon_values,
+            )
         return self.__post_process_nn_data(
             nn_lat_lon_idx=nn_lat_lon_idx,
             nn_values=nn_values,
             output_data=output_data,
         )
 
+    def __get_reference_lat_lon(self, file_path: Path) -> Tuple[np.array, np.array]:
+        with Dataset(file_path, "r", ExtractData.netcdf_format) as ref_dataset:
+            return (
+                ref_dataset.variables[self.lat_key][:],
+                ref_dataset.variables[self.lon_key][:],
+            )
+
     def __post_process_nn_data(
         self,
         nn_lat_lon_idx: Tuple[np.array, np.array],
@@ -511,10 +533,10 @@
         )
         corrected_nn_lat_lon, idx_map = get_unique_indices(nn_values)
         output_data.data_dict[output_data.var_lat_key] = corrected_nn_lat_lon[
-            self.lat_array_pos
+            ExtractData.lat_array_pos
         ]
         output_data.data_dict[output_data.var_lon_key] = corrected_nn_lat_lon[
-            self.lon_array_pos
+            ExtractData.lon_array_pos
         ]
         output_data.data_dict[output_data.var_station_idx_key] = idx_map
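The hunks above decouple the nearest-neighbor lookup from open netCDF handles: get_unmasked_nn_input_data and get_masked_nn_input_data now receive plain latitude, longitude and mask arrays, and BaseExtractor.__get_nn_lat_lon_idx gathers those arrays itself. Below is a minimal sketch of the unmasked path under stated assumptions: the file path and the "latitude"/"longitude" variable names are placeholders, and a SimpleNamespace stands in for InputData, whose constructor is not shown in this changeset.

    from pathlib import Path
    from types import SimpleNamespace

    import numpy as np
    from netCDF4 import Dataset

    from SDToolBox.extract_data import ExtractData

    # Hypothetical reference file; only the coordinate axes are read from it.
    ref_file = Path("data/era5_reference.nc")

    # Stand-in for InputData: only the attributes read by
    # get_unmasked_nn_input_data (_input_lat, _input_lon) are provided here.
    input_data = SimpleNamespace(
        _input_lat=np.array([52.1, 53.4]),
        _input_lon=np.array([4.3, 5.2]),
    )

    # Read the reference axes once; the variable names are an assumption.
    with Dataset(str(ref_file), "r") as ref_dataset:
        lat_values = ref_dataset.variables["latitude"][:]
        lon_values = ref_dataset.variables["longitude"][:]

    # New signature: plain arrays instead of a Dataset handle and a cases_dict.
    nn_lat_lon_idx, nn_values = ExtractData.get_unmasked_nn_input_data(
        input_data=input_data,
        lat_values=lat_values,
        lon_values=lon_values,
    )

For the masked path the same arrays would instead come from ExtractData.get_seamask_positions(mask_filepath), whose returned sea mask is passed on as mask_values.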
Index: trunk/SDToolBox/extract_data_utils.py
===================================================================
diff -u -r208 -r212
--- trunk/SDToolBox/extract_data_utils.py (.../extract_data_utils.py) (revision 208)
+++ trunk/SDToolBox/extract_data_utils.py (.../extract_data_utils.py) (revision 212)
@@ -75,6 +75,20 @@
     return np.array(output_idx), corrected_values
 
 
+def get_stacked_matrix(lat_array: np.array, lon_array: np.array) -> np.array:
+    """Stacks the given latitude and longitude arrays into a matrix that can
+    later be used to extract nearest neighbors.
+
+    Args:
+        lat_array (np.array): Array representing the Latitude axis.
+        lon_array (np.array): Array representing the Longitude axis.
+
+    Returns:
+        np.array: Matrix of [Lat, Lon] values.
+    """
+    return np.vstack((lat_array, lon_array)).T
+
+
 def get_matrix_nearest_neighbor(
     values: np.array, data_array: np.array
 ) -> Tuple[np.array, np.array]:
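The new get_stacked_matrix helper wraps the np.vstack((lat, lon)).T pattern used in extract_data.py. A short usage sketch together with get_matrix_nearest_neighbor, with purely illustrative coordinate values:

    import numpy as np

    from SDToolBox.extract_data_utils import (
        get_stacked_matrix,
        get_matrix_nearest_neighbor,
    )

    # Illustrative query and reference coordinates (not real ERA5 data).
    query_lat = np.array([52.1, 53.4])
    query_lon = np.array([4.3, 5.2])
    reference_lat = np.array([52.0, 52.5, 53.0, 53.5])
    reference_lon = np.array([4.0, 4.5, 5.0, 5.5])

    # Stack the coordinate pairs into [N, 2] matrices of [lat, lon] rows.
    query_matrix = get_stacked_matrix(query_lat, query_lon)
    reference_matrix = get_stacked_matrix(reference_lat, reference_lon)

    # Nearest-neighbor lookup of each query row against the reference rows.
    nn_idx, nn_values = get_matrix_nearest_neighbor(query_matrix, reference_matrix)

Each returned index refers to a row of the reference matrix, mirroring how get_masked_nn_input_data maps the result back onto the lat/lon grid with np.unravel_index.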