Index: trunk/SDToolBox/extract_data.py =================================================================== diff -u -r193 -r196 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 193) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 196) @@ -110,6 +110,18 @@ Returns: Tuple[np.array, np.array, np.array]: Sea positions, LON, LAT """ + sea_mask, LON, LAT = ExtractData.get_seamask(mask_filepath) + if np.where(sea_mask.ravel() == 1)[0].size == 0: + logging.error( + f"No sea positions were found at the extracted mask from {mask_filepath}." + ) + # positions where we are in the sea + logging.info(f"Generating sea mask with mask: {sea_mask}") + + return sea_mask, LON, LAT + + @staticmethod + def get_seamask(mask_filepath: Path,) -> Tuple[np.array, np.array, np.array]: logging.info(f"Getting masked file from {mask_filepath}.") with Dataset(mask_filepath, "r", "netCDF4") as sea_mask: logging.info(f"Extracting variable lsm from mask file.") @@ -126,14 +138,10 @@ LON = np.concatenate((LON[:, loni + 1 :], LON[:, 0 : loni + 1]), axis=1) mask = np.concatenate((mask[:, loni + 1 :], mask[:, 0 : loni + 1]), axis=1) - # positions where we are in the sea - logging.info(f"Generating sea mask with mask: {mask}") - possea = np.where(mask.ravel() == 0.0)[0] - if possea.size == 0: - logging.error( - f"No sea positions were found at the extracted mask from {mask_filepath}." - ) - return possea, LON, LAT + mask[mask != 0] = -1 + mask[mask == 0] = 1 + mask[mask == -1] = 0 + return mask, LON, LAT class BaseExtractor(ABC): file_var_key = "variable" @@ -520,25 +528,41 @@ Index of nearest neighbors and a Tuple of corrected Latitude and Longitude """ mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" - possea, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) - index, vals = self.get_nearest_neighbor_extended( - np.vstack((input_data._input_lon, input_data._input_lat)).T, - np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, - ) - index_abs = possea[index].squeeze() - if index_abs.size <= 0: - return None, None - indexes = np.unravel_index(index_abs, LON.shape) - unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) - nn_idx = indexes[1], indexes[0] - if len(unique_values.shape) == 1: - return ( - nn_idx, - (np.array([unique_values[1]]), np.array([unique_values[0]])), + sea_mask, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) + possea = np.where(sea_mask.ravel() == 1)[0] + + if not input_data.is_gridded: + index, vals = self.get_nearest_neighbor_extended( + np.vstack((input_data._input_lon, input_data._input_lat)).T, + np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T, ) - return nn_idx, (unique_values[:, 1], unique_values[:, 0]) + index_abs = possea[index].squeeze() + if index_abs.size <= 0: + return None, None + indexes = np.unravel_index(index_abs, LON.shape) + unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + nn_idx = indexes[1], indexes[0] + if len(unique_values.shape) == 1: + return ( + nn_idx, + (np.array([unique_values[1]]), np.array([unique_values[0]])), + ) + else: + index, vals = self.get_nearest_neighbor_extended( + np.vstack((input_data._input_lon, input_data._input_lat)).T, + np.vstack((LON.ravel(), LAT.ravel())).T, + ) + indexes = np.unravel_index(index.squeeze(), LON.shape) + nn_idx = indexes[1], indexes[0] + unique_values, unique_idx = np.unique(vals, axis=0, return_index=True) + input_data.sea_mask = sea_mask[ + np.ix_(sorted(indexes[0]), sorted(indexes[1])) + ] + logging.info("Added sea mask to input_data.") + return nn_idx, (unique_values[:, 1], unique_values[:, 0]) + @staticmethod def __get_case_subset( dataset: Dataset, @@ -583,6 +607,11 @@ + "[{}]".format(", ".join([str(coord) for coord in lat])) + "[{}]".format(", ".join([str(coord) for coord in lon])) ) + if hasattr(input_data, "sea_mask"): + return ( + np.ma.getdata(dataset[variable_name][:, lat, lon]) + * input_data.sea_mask + ) return np.ma.getdata(dataset[variable_name][:, lat, lon]) return np.ma.getdata(dataset[variable_name][:, lat[0], lon[0]]) Index: trunk/tests/test_acceptance.py =================================================================== diff -u -r193 -r196 --- trunk/tests/test_acceptance.py (.../test_acceptance.py) (revision 193) +++ trunk/tests/test_acceptance.py (.../test_acceptance.py) (revision 196) @@ -229,11 +229,11 @@ self, variable_key: str ): steplon = 0.5 - lonl = 3 + lonl = 5 lonr = 7.5 steplat = 0.5 latl = 52.5 - latu = 57 + latu = 55 coordsBOX = { "LON": np.arange(lonl, lonr + steplon, steplon).tolist(), "LAT": np.arange(latl, latu + steplat, steplat).tolist(),