Index: trunk/SDToolBox/extract_data_utils.py =================================================================== diff -u -r207 -r208 --- trunk/SDToolBox/extract_data_utils.py (.../extract_data_utils.py) (revision 207) +++ trunk/SDToolBox/extract_data_utils.py (.../extract_data_utils.py) (revision 208) @@ -54,7 +54,7 @@ def get_list_nearest_neighbor( input_values: List[float], reference_list: List[float], -) -> Tuple[List[int], List[float]]: +) -> Tuple[np.array, List[float]]: """Sets the nearest neighbor for all the elements given in the points_list. @@ -72,7 +72,7 @@ idx, value = get_single_nearest_neighbor(point, reference_list) corrected_values.append(value) output_idx.append(idx) - return output_idx, corrected_values + return np.array(output_idx), corrected_values def get_matrix_nearest_neighbor( Index: trunk/SDToolBox/extract_data.py =================================================================== diff -u -r207 -r208 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 207) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 208) @@ -180,51 +180,53 @@ Index of nearest neighbors and a Tuple of corrected Latitude and Longitude """ mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" - extracted_lat_pos = 1 - extracted_lon_pos = 0 + extracted_lat_pos = 0 + extracted_lon_pos = 1 sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) possea = np.where(sea_mask.ravel() == 1)[0] if not input_data.is_gridded: index, vals = get_matrix_nearest_neighbor( - np.vstack((input_data._input_lon, input_data._input_lat)).T, - np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, + np.vstack((input_data._input_lat, input_data._input_lon)).T, + np.vstack((LAT.ravel()[possea], LON.ravel()[possea])).T, ) index_abs = possea[index].squeeze() if index_abs.size <= 0: return None, None - indexes_lat_lon = np.unravel_index(index_abs, LON.shape) - unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + indexes_lat_lon, _ = get_unique_indices( + np.unravel_index(index_abs, LON.shape) + ) else: index, vals = get_matrix_nearest_neighbor( - np.vstack((input_data._input_lon, input_data._input_lat)).T, - np.vstack((LON.ravel(), LAT.ravel())).T, + np.vstack((input_data._input_lat, input_data._input_lon)).T, + np.vstack((LAT.ravel(), LON.ravel())).T, ) - indexes_lat_lon = np.unravel_index(index.squeeze(), LON.shape) - unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + indexes_lat_lon, _ = get_unique_indices( + np.unravel_index(index.squeeze(), LON.shape) + ) - if index.shape[extracted_lon_pos] > 1: + if index.shape[0] > 1: input_data.sea_mask = sea_mask[ - np.ix_(np.sort(indexes_lat_lon[0]), np.sort(indexes_lat_lon[1]),) + np.ix_(indexes_lat_lon[0], indexes_lat_lon[1],) ] else: input_data.sea_mask = np.array( sea_mask[indexes_lat_lon[0], indexes_lat_lon[1],] ) logging.info(f"Added sea mask to input_data: {input_data.sea_mask}") - if len(unique_values.shape) == 1: + if len(indexes_lat_lon) == 1: return ( indexes_lat_lon, ( - np.array([unique_values[extracted_lat_pos]]), - np.array([unique_values[extracted_lon_pos]]), + np.array([vals[extracted_lat_pos]]), + np.array([vals[extracted_lon_pos]]), ), ) return ( indexes_lat_lon, - (unique_values[:, extracted_lat_pos], unique_values[:, extracted_lon_pos],), + (vals[:, extracted_lat_pos], vals[:, extracted_lon_pos],), ) class BaseExtractor(ABC): @@ -523,7 +525,7 @@ ] output_data.data_dict[output_data.var_station_idx_key] = idx_map logging.info( - f"""Corrected nn indices: + f"""Corrected nn indices: Lat: {corrected_nn_lat_lon_idx[0].astype(str)} Lon: {corrected_nn_lat_lon_idx[1].astype(str)}""" ) Index: trunk/SDToolBox/input_data.py =================================================================== diff -u -r195 -r208 --- trunk/SDToolBox/input_data.py (.../input_data.py) (revision 195) +++ trunk/SDToolBox/input_data.py (.../input_data.py) (revision 208) @@ -164,7 +164,9 @@ self.__get_corrected_longitude(lon) for lon in self._input_lon ] if self.is_gridded: - self._input_lat = sorted(set(self._input_lat)) + # Latitudes go from MAX to MIN + self._input_lat = sorted(set(self._input_lat))[::-1] + # Longitudes go from MIN to MAX self._input_lon = list(sorted(set(self._input_lon))) assert len(self._input_lat) == len( self._input_lon