Index: trunk/tests/test_acceptance.py =================================================================== diff -u -r192 -r193 --- trunk/tests/test_acceptance.py (.../test_acceptance.py) (revision 192) +++ trunk/tests/test_acceptance.py (.../test_acceptance.py) (revision 193) @@ -202,9 +202,10 @@ steplat = 0.25 latl = 52.5 latu = 57 - x_range = np.arange(lonl, lonr + steplon, steplon).tolist() - y_range = np.arange(latl, latu + steplat, steplat).tolist() - coordsBOX = {"LON": x_range, "LAT": y_range} + coordsBOX = { + "LON": np.arange(lonl, lonr + steplon, steplon).tolist(), + "LAT": np.arange(latl, latu + steplat, steplat).tolist(), + } timeWAMy = list(range(1981, 1987)) # use the SDToolBox function to create input data Input_DataBOX = InputData( @@ -221,6 +222,36 @@ # assert extracted_data.variables["msl"].any(), "No values generated for key msl." ChunkDataUtils.assert_valid_ndarray(extracted_data.variables["msl"]) + @pytest.mark.parametrize( + "variable_key", [pytest.param("swh"), pytest.param("msl"),], + ) + def test_given_gridded_test_case_lat_lon_returned_correctly( + self, variable_key: str + ): + steplon = 0.5 + lonl = 3 + lonr = 7.5 + steplat = 0.5 + latl = 52.5 + latu = 57 + coordsBOX = { + "LON": np.arange(lonl, lonr + steplon, steplon).tolist(), + "LAT": np.arange(latl, latu + steplat, steplat).tolist(), + } + timeWAMy = list(range(1981, 1987)) + Input_DataBOX = InputData( + input_coordinates=coordsBOX, + input_variables=[variable_key], + input_scenarios=["era5"], + input_years=timeWAMy, + is_gridded=True, + ) + + dir_test_data = TestUtils.get_local_test_data_dir("chunked_data") + extracted_data = ExtractData.get_era_5(dir_test_data, Input_DataBOX) + assert extracted_data.data_dict["lat"] == coordsBOX["LAT"] + assert extracted_data.data_dict["lon"] == coordsBOX["LON"] + def test_extract_gridded_data_multivariable(self): chunked_lon, chunked_lat = ChunkDataUtils.get_default_chunked_lon_lat() coordsBOX = {"LON": chunked_lon, "LAT": chunked_lat} Index: 
trunk/SDToolBox/extract_data.py =================================================================== diff -u -r192 -r193 --- trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 192) +++ trunk/SDToolBox/extract_data.py (.../extract_data.py) (revision 193) @@ -245,22 +245,22 @@ # Set the nearest neighbors and the time reference. logging.info("Getting nearest neighbors.") - nn_idx, nn_values = self.__get_nearest_neighbors_lon_lat( + nn_idx, nn_values = self.__get_nearest_neighbors_lat_lon( ref_file_path=self.__first_filepath, input_data=input_data, cases_dict=output_data.data_dict, ) logging.info( - f"Nearest neighbors: \nLon:{nn_values[0].astype(str)} \nLat:{nn_values[1].astype(str)}" + f"Nearest neighbors: \nLat:{nn_values[0].astype(str)} \nLon:{nn_values[1].astype(str)}" ) nn_idx, idx_map = self.__get_unique_indices(nn_idx) - output_data.data_dict[output_data.var_lon_key] = [ + output_data.data_dict[output_data.var_lat_key] = [ value for idx, value in enumerate(nn_values[0].tolist()) if idx in idx_map ] - output_data.data_dict[output_data.var_lat_key] = [ + output_data.data_dict[output_data.var_lon_key] = [ value for idx, value in enumerate(nn_values[1].tolist()) if idx in idx_map @@ -379,7 +379,7 @@ axis=int(not input_data.is_gridded), ) - def __get_nearest_neighbors_lon_lat( + def __get_nearest_neighbors_lat_lon( self, ref_file_path: str, input_data: InputData, cases_dict: dict ) -> Tuple[Tuple[List[int], List[int]], Tuple[np.array, np.array]]: """Gets the corrected index and value for the given @@ -441,8 +441,8 @@ is_gridded=input_data.is_gridded, ) return ( - (nn_lon_idx, nn_lat_idx), - (np.array(nn_lon_values), np.array(nn_lat_values)), + (nn_lat_idx, nn_lon_idx), + (np.array(nn_lat_values), np.array(nn_lon_values)), ) def _set_nn( @@ -508,6 +508,17 @@ def get_masked_nn_lat_lon( self, input_data: InputData, ref_dataset: Dataset, mask_dir: Path, ) -> Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: + """Returns the masked latitude 
and longitude of a given dataset. + + Args: + input_data (InputData): Input data with coordinates. + ref_dataset (Dataset): Dataset that needs to be masked. + mask_dir (Path): Path to the mask directory. + + Returns: + Tuple[Tuple[List[int], List[int]], Tuple[List[float], List[float]]]: + Index of nearest neighbors and a Tuple of corrected Latitude and Longitude + """ mask_filepath = Path(mask_dir) / "ERA5_landsea_mask.nc" possea, LON, LAT = ExtractData.get_seamask_positions(mask_filepath) index, vals = self.get_nearest_neighbor_extended(