Index: trunk/SDToolBox/extract_data.py =================================================================== diff -u -r205 -r206 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 205) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 206) @@ -152,6 +152,8 @@ file_scenario_key = "scenario" netcdf_format = "netCDF4" stations_idx_key = "stations_idx" + lon_array_pos = 1 + lat_array_pos = 0 __file_iterator: list = [] @@ -253,27 +255,29 @@ # Set the nearest neighbors and the time reference. logging.info("Getting nearest neighbors.") - nn_idx, nn_values = self.__get_nearest_neighbors_lat_lon( + nn_lat_lon_idx, nn_values = self.__get_nearest_neighbors_lat_lon( ref_file_path=self.__first_filepath, input_data=input_data, cases_dict=output_data.data_dict, ) logging.info( - f"Nearest neighbors: \nLat:{nn_values[0].astype(str)} \nLon:{nn_values[1].astype(str)}" + f"Nearest neighbors: \nLat:{nn_values[self.lat_array_pos].astype(str)} \nLon:{nn_values[self.lon_array_pos].astype(str)}" ) - nn_idx, idx_map = self.__get_unique_indices(nn_idx) + # Remove duplicates and store lat/lon values and their stations. + nn_lat_lon_idx, idx_map = self.__get_unique_indices(nn_lat_lon_idx) output_data.data_dict[output_data.var_lat_key] = [ value - for idx, value in enumerate(nn_values[0].tolist()) + for idx, value in enumerate(nn_values[self.lat_array_pos].tolist()) if idx in idx_map ] output_data.data_dict[output_data.var_lon_key] = [ value - for idx, value in enumerate(nn_values[1].tolist()) + for idx, value in enumerate(nn_values[self.lon_array_pos].tolist()) if idx in idx_map ] output_data.data_dict[output_data.var_station_idx_key] = idx_map + logging.info("Getting time references.") time_entry_refs = self.__get_time_ref_group() @@ -296,7 +300,7 @@ input_data, input_dataset, variable_key, - nn_idx, + nn_lat_lon_idx, ), ) # Set the time if the file needs to be considered. @@ -319,20 +323,24 @@ return output_data def __get_unique_indices( - self, nn_idx: Tuple[list, list] + self, nn_lat_lon_idx: Tuple[list, list] ) -> Tuple[Tuple[list, list], list]: # Get unique indices. - if not isinstance(nn_idx[0], (list, np.ndarray)): - nn_idx = [nn_idx[0]], [nn_idx[1]] - combined_array = np.stack(nn_idx, axis=-1) - unique_nn, unique_idx = np.unique(combined_array, return_index=True, axis=0) + if not isinstance(nn_lat_lon_idx[0], (list, np.ndarray)): + nn_lat_lon_idx = [nn_lat_lon_idx[0]], [nn_lat_lon_idx[1]] + combined_array = np.stack(nn_lat_lon_idx, axis=1) + _, unique_idx = np.unique(combined_array, return_index=True, axis=0) + unique_combinations = combined_array[sorted(unique_idx)] new_nn_idx = ( - [nn_idx[0][i] for i in unique_idx], - [nn_idx[1][i] for i in unique_idx], + unique_combinations[:, 0], + unique_combinations[:, 1], ) return ( new_nn_idx, - [unique_nn.tolist().index(nn.tolist()) for nn in combined_array], + [ + unique_combinations.tolist().index(nn.tolist()) + for nn in combined_array + ], ) def __get_time_ref_group(self) -> List[str]: @@ -494,9 +502,9 @@ return index_found, value_found @staticmethod - def get_nearest_neighbor_extended( - values, data_array: np.array - ) -> Tuple[int, int]: + def get_nearest_neighbor_extended_lon_lat( + values: np.array, data_array: np.array + ) -> Tuple[np.array, np.array]: """ search for nearest decimal degree in an array of decimal degrees and return the index. @@ -528,44 +536,57 @@ Index of nearest neighbors and a Tuple of corrected Latitude and Longitude """ mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" - + extracted_lat_pos = 1 + extracted_lon_pos = 0 sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) possea = np.where(sea_mask.ravel() == 1)[0] if not input_data.is_gridded: - index, vals = self.get_nearest_neighbor_extended( + index, vals = self.get_nearest_neighbor_extended_lon_lat( np.vstack((input_data._input_lon, input_data._input_lat)).T, np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, ) index_abs = possea[index].squeeze() if index_abs.size <= 0: return None, None - indexes = np.unravel_index(index_abs, LON.shape) + indexes_lat_lon = np.unravel_index(index_abs, LON.shape) unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) - nn_idx = indexes[0], indexes[1] else: - index, vals = self.get_nearest_neighbor_extended( + index, vals = self.get_nearest_neighbor_extended_lon_lat( np.vstack((input_data._input_lon, input_data._input_lat)).T, np.vstack((LON.ravel(), LAT.ravel())).T, ) - indexes = np.unravel_index(index.squeeze(), LON.shape) - nn_idx = indexes[0], indexes[1] + indexes_lat_lon = np.unravel_index(index.squeeze(), LON.shape) unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) - if index.shape[0] > 1: + + if index.shape[extracted_lon_pos] > 1: input_data.sea_mask = sea_mask[ - np.ix_(np.sort(indexes[0]), np.sort(indexes[1])) + np.ix_( + np.sort(indexes_lat_lon[0]), np.sort(indexes_lat_lon[1]), + ) ] else: - input_data.sea_mask = np.array(sea_mask[indexes[0], indexes[1]]) + input_data.sea_mask = np.array( + sea_mask[indexes_lat_lon[0], indexes_lat_lon[1],] + ) logging.info(f"Added sea mask to input_data: {input_data.sea_mask}") if len(unique_values.shape) == 1: return ( - nn_idx, - (np.array([unique_values[1]]), np.array([unique_values[0]])), + indexes_lat_lon, + ( + np.array([unique_values[extracted_lat_pos]]), + np.array([unique_values[extracted_lon_pos]]), + ), ) - return nn_idx, (unique_values[:, 1], unique_values[:, 0]) + return ( + indexes_lat_lon, + ( + unique_values[:, extracted_lat_pos], + unique_values[:, extracted_lon_pos], + ), + ) @staticmethod def __get_case_subset(