Index: trunk/tests/test_output_data.py
===================================================================
diff -u -r137 -r138
--- trunk/tests/test_output_data.py	(.../test_output_data.py)	(revision 137)
+++ trunk/tests/test_output_data.py	(.../test_output_data.py)	(revision 138)
@@ -204,8 +204,8 @@
         latmin = 50
         latmax = 62
         dlat = 0.5
-        lon = np.arange(latmin, latmax + dlat, dlat)
-        lat = np.arange(lonmin, lonmax + dlon, dlon)
+        lon = np.arange(latmin, latmax + dlat, dlat)[4:8]
+        lat = np.arange(lonmin, lonmax + dlon, dlon)[4:8]
         input_data.input_coordinates = {"LON": lon, "LAT": lat}
         input_data.is_gridded = False
         input_data.input_years = [1981, 1982]
@@ -237,7 +237,7 @@
         output_test_data = TestUtils.get_local_test_data_dir("output_data")
         input_data = InputData()
         input_data.input_variables = ["swh"]
-        input_data.input_coordinates = [(4.2, 2.4), (2.5, 42), (4.2, 3.6)]
+        input_data.input_coordinates = {"LON": [4.2, 2.5, 4.3], "LAT": [2.4, 42, 3.6]}
         input_data.is_gridded = True
         input_data.input_years = [1981, 1982]
         netcdf_filepath = None
@@ -267,9 +267,7 @@
         # 1. Given
         input_data = InputData()
         input_data.input_variables = ["swh"]
-        input_data.input_coordinates = [
-            (4.2, 2.4),
-        ]
+        input_data.input_coordinates = {"LON"}
         input_data.input_years = [1981, 1982]
         output_data = None
         generated_output = None
@@ -293,9 +291,10 @@
 
         input_data = InputData()
         input_data.input_variables = ["swh"]
-        input_data.input_coordinates = [
-            (4.2, 2.4),
-        ]
+        input_data.input_coordinates = {
+            input_data.longitude_key: [50.5],
+            input_data.latitude_key: [52.5],
+        }
         input_data.input_years = [1981, 1982]
         return_value = None
 
@@ -330,9 +329,17 @@
         input_data.input_variables = ["swh"]
         input_data.input_years = [1981, 1982]
 
+        lonmin = -5
+        lonmax = 11
+        dlon = 0.5
+        latmin = 50
+        latmax = 62
+        dlat = 0.5
+        lon = np.arange(latmin, latmax + dlat, dlat)
+        lat = np.arange(lonmin, lonmax + dlon, dlon)
         input_data.input_coordinates = {
-            "LON": [(4.2, 2.6), (14, 4.2), (42, 24), (4.2, 2.4)],
-            "LAT": [(4.2, 2.6), (14, 4.2), (42, 24), (4.2, 2.4)],
+            input_data.longitude_key: lon,
+            input_data.latitude_key: lat,
         }
         input_data.is_gridded = False
 
Index: trunk/tests/test_extract_data.py
===================================================================
diff -u -r137 -r138
--- trunk/tests/test_extract_data.py	(.../test_extract_data.py)	(revision 137)
+++ trunk/tests/test_extract_data.py	(.../test_extract_data.py)	(revision 138)
@@ -15,27 +15,50 @@
 
 
 class Test_get_era5:
-    @pytest.mark.systemtest
-    def test_given_list_of_coordinates_then_subset_is_extracted(self):
-        # 1. Given
-        # When using local data you can just replace the comment in these lines
-        dir_test_data = Path(TestUtils.get_local_test_data_dir("chunked_data"))
-
-        input_data = InputData()
+    @pytest.fixture
+    def default_lon_lat(self) -> Tuple[np.ndarray, np.ndarray]:
         lonmin = -5
         lonmax = 11
         dlon = 0.5
         latmin = 50
         latmax = 62
         dlat = 0.5
-        lon = np.arange(latmin, latmax + dlat, dlat)
-        lat = np.arange(lonmin, lonmax + dlon, dlon)
+        lon = np.arange(latmin, latmax + dlat, dlat)[4:8]
+        lat = np.arange(lonmin, lonmax + dlon, dlon)[4:8]
+        return lon, lat
 
+    @pytest.mark.integrationtest
+    def test_verify_seamask_has_sea_positions(
+        self, default_lon_lat: Tuple[np.ndarray, np.ndarray]
+    ):
+        # 1. Given
+        # dir_test_data = Path(TestUtils.get_local_test_data_dir("chunked_data"))
+        dir_test_data = Path("P:\\metocean-data\\open\\ERA5\\data\\Global")
+        mask_filepath = dir_test_data / "ERA5_landsea_mask.nc"
+
+        # 2. Verify netcdf
+        assert mask_filepath.is_file()
+        result_possea, LON, LAT = ExtractData.get_seamask_positions(mask_filepath)
+
+        # 3. Verify mask
+        assert result_possea.size > 0, "No sea positions found in mask."
+        assert LON.size > 0, "No masked lon coordinates."
+        assert LAT.size > 0, "No masked lat coordinates."
+
+    @pytest.mark.systemtest
+    def test_given_list_of_coordinates_then_subset_is_extracted(
+        self, default_lon_lat: Tuple[np.ndarray, np.ndarray]
+    ):
+        # 1. Given
+        # When using local data you can just replace the comment in these lines
+        dir_test_data = Path(TestUtils.get_local_test_data_dir("chunked_data"))
+
+        input_data = InputData()
         input_data.input_variables = ["swh"]
         input_data.input_years = [1981, 1982]
         input_data.input_coordinates = {
-            input_data.longitude_key: lon,
-            input_data.latitude_key: lat,
+            input_data.longitude_key: default_lon_lat[0],
+            input_data.latitude_key: default_lon_lat[1],
         }
 
         # 2. When
@@ -74,7 +97,7 @@
         dlat = 0.5
         lon = np.arange(latmin, latmax + dlat, dlat)
         lat = np.arange(lonmin, lonmax + dlon, dlon)
-        input_data.input_coordinates = {"LAT": lat[1:2], "LON": lon[1:2]}
+        input_data.input_coordinates = {"LAT": [lat[1:2]], "LON": lon[1:2]}
         input_data.input_years = [1981, 1982]
 
         # 2. When
Index: trunk/SDToolBox/extract_data.py
===================================================================
diff -u -r136 -r138
--- trunk/SDToolBox/extract_data.py	(.../extract_data.py)	(revision 136)
+++ trunk/SDToolBox/extract_data.py	(.../extract_data.py)	(revision 138)
@@ -8,6 +8,7 @@
 import logging
 from typing import List, Dict, Tuple, Any
+from pathlib import Path
 from abc import ABC, abstractmethod
 import itertools
 
 
@@ -96,6 +97,44 @@
             "EARTH (scenarios)": earth_extractor.possible_scenarios,
         }
 
+    @staticmethod
+    def get_seamask_positions(
+        mask_filepath: Path,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Gets the seamask positions for a given filepath. Returns
+        the sea positions and the masked LON and LAT grids.
+
+        Args:
+            mask_filepath (Path): Filepath to the netcdf mask.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray, np.ndarray]: Sea positions, LON, LAT.
+        """
+        logging.info(f"Getting masked file from {mask_filepath}.")
+        with Dataset(mask_filepath, "r", "netCDF4") as sea_mask:
+            logging.info("Extracting variable lsm from mask file.")
+            mask = sea_mask["lsm"][0, 0::2, 0::2]
+
+            logging.info("Correcting masked longitudes above 180.")
+            lon = sea_mask["longitude"][0::2]
+            lon[lon > 180] += -360
+            loni = np.where(lon == lon.max())[0][0]
+
+            logging.info("Maximum longitude in mask is: {}".format(lon.max()))
+            lat = sea_mask["latitude"][0::2]
+            LON, LAT = np.meshgrid(lon, lat)
+
+            LON = np.concatenate((LON[:, loni + 1 :], LON[:, 0 : loni + 1]), axis=1)
+            mask = np.concatenate((mask[:, loni + 1 :], mask[:, 0 : loni + 1]), axis=1)
+            # positions where we are in the sea
+            logging.info(f"Generating sea mask with mask: {mask}")
+            possea = np.where(mask.ravel() == 0.0)[0]
+            if possea.size == 0:
+                logging.error(
+                    f"No sea positions were found at the extracted mask from {mask_filepath}."
+                )
+            return possea, LON, LAT
+
 
 class BaseExtractor(ABC):
     file_var_key = "variable"
     file_key_key = "key"
@@ -470,39 +509,21 @@
         input_data: InputData,
         ref_dataset: Dataset,
         cases_dict: dict,
-        mask_dir: str,
+        mask_dir: Path,
     ) -> Tuple[List[int], List[int]]:
-        mask_filepath = os.path.join(mask_dir, "ERA5_landsea_mask.nc")
-        logging.info(f"Getting masked file from {mask_filepath}.")
-        with Dataset(mask_filepath, "r", self.netcdf_format) as dsmask:
-            logging.info(f"Extracting variable lsm from mask file.")
-            mask = dsmask["lsm"][0, 0::2, 0::2]
-            lon = dsmask["longitude"][0::2]
-            logging.info("Correcting masked longituds above 180.")
-            lon[lon > 180] += -360
-            lat = dsmask["latitude"][0::2]
-            logging.info("Maximum longitud in mask is: {}".format(lon.max()))
-            loni = np.where(lon == lon.max())[0][0]
-            LON, LAT = np.meshgrid(lon, lat)
-            LON = np.concatenate((LON[:, loni + 1 :], LON[:, 0 : loni + 1]), axis=1)
-            mask = np.concatenate(
-                (mask[:, loni + 1 :], mask[:, 0 : loni + 1]), axis=1
-            )
-            # positions where we are in the sea
-            possea = np.where(mask.ravel() == 0.0)[0]
-            if possea.size == 0:
-                logging.error("No sea positions were found at the extracted mask.")
-            index, vals = self.get_nearest_neighbor_extended(
-                np.vstack((input_data._input_lon, input_data._input_lat)).T,
-                np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T,
-            )
-            index_abs = possea[index].squeeze()
+        mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc"
+        possea, LON, LAT = ExtractData.get_seamask_positions(mask_filepath)
+        index, vals = self.get_nearest_neighbor_extended(
+            np.vstack((input_data._input_lon, input_data._input_lat)).T,
+            np.vstack((LON.ravel()[possea], LAT.ravel()[possea])).T,
+        )
+        index_abs = possea[index].squeeze()
 
-            if index_abs.size <= 0:
-                return None, None
-            indexes = np.unravel_index(index_abs, LON.shape)
-            unique_values, unique_idx = np.unique(vals, axis=0, return_index=True)
-            return indexes[1], indexes[0]
+        if index_abs.size <= 0:
+            return None, None
+        indexes = np.unravel_index(index_abs, LON.shape)
+        unique_values, unique_idx = np.unique(vals, axis=0, return_index=True)
+        return indexes[1], indexes[0]
 
     @staticmethod
     def __get_case_subset(
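
Review note: the revision extracts the longitude-wrap and sea-index logic into the reusable
ExtractData.get_seamask_positions. Below is a minimal, self-contained sketch of the same
steps on a synthetic mask, for reviewing the logic without a netCDF file on hand. The toy
lon, lat, and mask arrays (and the 1.0 = land / 0.0 = sea convention) are illustrative
assumptions only; the real method reads the downsampled "lsm", "longitude", and "latitude"
variables from ERA5_landsea_mask.nc.

    import numpy as np

    # Toy 3x4 land-sea mask on a 0..360-degree grid (assumed values;
    # the real mask comes from the "lsm" variable in ERA5_landsea_mask.nc).
    lon = np.array([0.0, 90.0, 180.0, 270.0])
    lat = np.array([-10.0, 0.0, 10.0])
    mask = np.array(
        [
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
        ]
    )

    # Wrap longitudes above 180 into [-180, 180], as get_seamask_positions does.
    lon[lon > 180] += -360
    loni = np.where(lon == lon.max())[0][0]

    # Build the grids, then roll the columns so they run west to east after the
    # wrap. Latitude is constant along columns, so only LON and mask need rolling.
    LON, LAT = np.meshgrid(lon, lat)
    LON = np.concatenate((LON[:, loni + 1 :], LON[:, 0 : loni + 1]), axis=1)
    mask = np.concatenate((mask[:, loni + 1 :], mask[:, 0 : loni + 1]), axis=1)

    # Flat indices of the sea cells (mask == 0.0); pairing them with the raveled
    # grids yields the candidate coordinates for the nearest-neighbor lookup.
    possea = np.where(mask.ravel() == 0.0)[0]
    print(LON.ravel()[possea])  # sea longitudes
    print(LAT.ravel()[possea])  # sea latitudes

In the refactored method above, possea then feeds get_nearest_neighbor_extended, and
np.unravel_index(index_abs, LON.shape) converts the winning flat indices back into grid
row/column indices.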