Fisheye: Tag 16 refers to a dead (removed) revision in file `trunk/tests/test_extract_waves_era5.py'. Fisheye: No comparison available. Pass `N' to diff? Index: trunk/tests/test_data_acquisition.py =================================================================== diff -u -r14 -r16 --- trunk/tests/test_data_acquisition.py (.../test_data_acquisition.py) (revision 14) +++ trunk/tests/test_data_acquisition.py (.../test_data_acquisition.py) (revision 16) @@ -8,15 +8,15 @@ import string from tests.TestUtils import TestUtils -from SDToolBox.data_acquisition import DataAcquisition +from SDToolBox.data_acquisition import InputData class Test_CreateDataAcquisition: @pytest.mark.unittest def test_when_create_empty_dataacquistion_no_exception_risen(self): try: - DataAcquisition() + InputData() except Exception as e_info: err_mssg = 'Error while creating DataAquisition object.' + \ '{}'.format(str(e_info)) @@ -27,7 +27,7 @@ @pytest.mark.unittest def test_when_no_dataset_then_returns_false(self): - data = DataAcquisition() + data = InputData() validation_result = data.validate(None) assert validation_result is False @@ -41,7 +41,7 @@ date_from = date_to + timedelta(days=1) # 2. When - validation_result = DataAcquisition.validate_input_dates( + validation_result = InputData.validate_input_dates( date_from, date_to) # 3. Then @@ -54,7 +54,7 @@ date_to = date_from + timedelta(days=1) # 2. When - validation_result = DataAcquisition.validate_input_dates( + validation_result = InputData.validate_input_dates( date_from, date_to) # 3. Then @@ -73,7 +73,7 @@ ) def test_when_no_coordinates_given_returns_false(self, coord_list): # 1. Given - data_acq = DataAcquisition() + data_acq = InputData() dataset = None # 2. When @@ -85,7 +85,7 @@ @pytest.mark.unittest def test_when_no_dataset_is_given_returns_false(self): # 1. Given - data_acq = DataAcquisition() + data_acq = InputData() data_acq.coord_list = [(2.4, 4.2)] dataset = None @@ -107,7 +107,7 @@ def test_when_dataset_without_keys_given_returns_false( self, lon_key: str, lat_key: str): # 1. Given - data_acq = DataAcquisition() + data_acq = InputData() data_acq.lon_key = lon_key data_acq.lat_key = lat_key data_acq.coord_list = [(2.4, 4.2)] @@ -134,7 +134,7 @@ 'given_one_coordinate.nc' # 1. Given - data_acq = DataAcquisition() + data_acq = InputData() data_acq.lon_key = 'Longitude' data_acq.lat_key = 'Latitude' data_acq.coord_list = [(2.4, 4.2)] @@ -162,7 +162,7 @@ Dataset -- NetCDF4 Dataset """ # 1. Create some data. - lon = np.arange(45, 101, 2) + lon = np.arange(0, 56, 2) lat = np.arange(-30, 25, 2.5) z = np.arange(0, 200, 10) x = np.random.randint(10, 25, size=(len(lon), len(lat), len(z))) Index: trunk/tests/test_extract_waves_era5.py =================================================================== diff -u -r15 -r16 --- trunk/tests/test_extract_waves_era5.py (.../test_extract_waves_era5.py) (revision 15) +++ trunk/tests/test_extract_waves_era5.py (.../test_extract_waves.py) (revision 16) @@ -1,98 +1,141 @@ -import unittest, pytest +import pytest import os from os import path import netCDF4 from tests.TestUtils import TestUtils + from SDToolBox import main as main -from SDToolBox import extract_waves_era5 +from SDToolBox import extract_waves +from SDToolBox import data_acquisition import numpy as np -class Test_ExtractWavesEra5: - """ - Creates a netCFD4 file stores it in the local test data dir - It then checks it can correctly read it - """ + +class Test_create: + @pytest.mark.unittest - def test_create_and_read_dummy_netCDF_data(self): - #dummy test to create and read netCDF test file - keys_list = [] - #expected_keys_list = dict(lon = 3, lat = 3, time = None) - x = np.array([[1,2,3],[4,5,6]], np.int32) + def test_given_no_input_data_then_exception_is_risen(self): + # 1. Given + input_data = None + output_result = None + expected_error = 'No valid input data.' + # 2. When + with pytest.raises(IOError) as e_info: + output_result = extract_waves.ExtractWaves(input_data) - data_dir= TestUtils.get_local_test_data_dir("netCDF_dummy_data") - data_file = data_dir + "era5_Global_Hs_1980.nc" + # 3. Then + error_message = str(e_info.value) + assert output_result is None + assert error_message == expected_error, '' + \ + 'Expected exception message {},'.format(expected_error) + \ + 'retrieved {}'.format(error_message) + + +class Test_subset_era_5: + + @pytest.mark.systemtest + def test_given_waves_folder_then_subset_collection_is_extracted(self): + # 1. Given + # When using local data you can just replace the comment in these lines + #dir_test_data = TestUtils.get_test_data_dir('netCDF_Waves_data') + dir_test_data = 'P:\\metocean-data\\open\\ERA5\\data\\Global' + input_data = data_acquisition.InputData() + input_data.coord_list = [(4.2, 2.4)] + + # 2. When + extract_wave = extract_waves.ExtractWaves(input_data) + dataset_list = extract_wave.subset_era_5(dir_test_data, 1980, 1982) + + # 3. Then + assert dataset_list is not None + + +# class Test_ExtractWavesEra5: +# """ +# Creates a netCFD4 file stores it in the local test data dir +# It then checks it can correctly read it +# """ +# @pytest.mark.unittest +# def test_create_and_read_dummy_netCDF_data(self): +# # dummy test to create and read netCDF test file +# keys_list = [] +# # expected_keys_list = dict(lon = 3, lat = 3, time = None) +# x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) + +# data_dir= TestUtils.get_local_test_data_dir("netCDF_dummy_data") +# data_file = data_dir + "era5_Global_Hs_1980.nc" - if not path.exists(data_file) : - f = netCDF4.Dataset(data_file,'w', format='NETCDF4') - else : - f = netCDF4.Dataset(data_file,'r', format='NETCDF4') +# if not path.exists(data_file): +# f = netCDF4.Dataset(data_file,'w', format='NETCDF4') +# else: +# f = netCDF4.Dataset(data_file,'r', format='NETCDF4') - tempgrp = f.createGroup('Temp_data') - tempgrp = f.createGroup('Temp_data') - tempgrp.createDimension('lon', len(x[0])) - tempgrp.createDimension('lat', len(x[1])) - tempgrp.createDimension('time', None) - f.close() - f = netCDF4.Dataset(data_file,'r') - tempgrp = f.groups['Temp_data'] - keys_list=tempgrp.dimensions.keys() - f.close() +# tempgrp = f.createGroup('Temp_data') +# tempgrp = f.createGroup('Temp_data') +# tempgrp.createDimension('lon', len(x[0])) +# tempgrp.createDimension('lat', len(x[1])) +# tempgrp.createDimension('time', None) +# f.close() +# f = Dataset(data_file, 'r') +# tempgrp = f.groups['Temp_data'] +# keys_list=tempgrp.dimensions.keys() +# f.close() - # assert expected_keys_list[0] == keys_list[0] - # assert expected_keys_list[1] == keys_list[1] - #assert expected_keys_list[2] == keys_list[2] - """ - Instantiates a wave extraction class and checks that - it is correctly created - """ - @pytest.mark.unittest - def test_instantiating_extract_wave_era5_returns_allocated_object(self): - waves_era5 = extract_waves_era5.ExtractWavesEra5(0.0,0.0,"path", 1980,2000) - assert waves_era5.lat == 0.0 - assert waves_era5.lon == 0.0 - assert waves_era5.dpath == "path" - assert waves_era5.year1 == 1980 - assert waves_era5.yearN == 2000 - """ - Checks that the longitude is normalized if a value higher than 180 is passed - """ - @pytest.mark.unittest - def test_given_longitude_higher_than_180_returns_normalized_longitude(self): - #setup - longitude = 400 - normalized_longitude = longitude -180 - waves_era5 = extract_waves_era5.ExtractWavesEra5(0.0,0.0,"path", 1980,2000) - #call - result = waves_era5.check_for_longitude(longitude) - #assert - assert result == normalized_longitude - """ - Checks that the longitude is unchanged if a value lower than 180 is passed - """ - @pytest.mark.unittest - def test_given_longitude_lower_than_180_returns_normalized_unchanged(self): - #setup - longitude = 30 - waves_era5 = extract_waves_era5.ExtractWavesEra5(0.0,0.0,"path", 1980,2000) - #call - result = waves_era5.check_for_longitude(longitude) - #assert - assert result == 30 - """ - Checks that array of years if correctly generated - """ - @pytest.mark.unittest - def test_years_array_is_correctly_generated(self): - #setup - year1 = 1980 - yearN = 1983 - result_array = [1980,1981,1982,1983] - result = [] - waves_era5 = extract_waves_era5.ExtractWavesEra5(0.0,0.0,"path", 1980,2000) - #call - result = waves_era5.generate_years_array(year1, yearN) - #assert - assert result[0] == result_array[0] - assert result[1] == result_array[1] - assert result[2] == result_array[2] - assert result[3] == result_array[3] +# # assert expected_keys_list[0] == keys_list[0] +# # assert expected_keys_list[1] == keys_list[1] +# # assert expected_keys_list[2] == keys_list[2] + +# """ +# Instantiates a wave extraction class and checks that +# it is correctly created +# """ +# @pytest.mark.unittest +# def test_instantiating_extract_wave_era5_returns_allocated_object(self): +# waves_era5 = extract_waves.ExtractWaves(0.0,0.0,"path", 1980,2000) +# assert waves_era5.lat == 0.0 +# assert waves_era5.lon == 0.0 +# assert waves_era5.dpath == "path" +# assert waves_era5.year1 == 1980 +# assert waves_era5.yearN == 2000 +# """ +# Checks that the longitude is normalized if a value higher than 180 is passed +# """ +# @pytest.mark.unittest +# def test_given_longitude_higher_than_180_returns_normalized_longitude(self): +# #setup +# longitude = 400 +# normalized_longitude = longitude -180 +# waves_era5 = extract_waves_era.subset_era_5(0.0,0.0,"path", 1980,2000) +# #call +# result = waves_era5.check_for_longitude(longitude) +# #assert +# assert result == normalized_longitude +# """ +# Checks that the longitude is unchanged if a value lower than 180 is passed +# """ +# @pytest.mark.unittest +# def test_given_longitude_lower_than_180_returns_normalized_unchanged(self): +# #setup +# longitude = 30 +# waves_era5 = extract_waves_era.subset_era_5(0.0,0.0,"path", 1980,2000) +# #call +# result = waves_era5.check_for_longitude(longitude) +# #assert +# assert result == 30 +# """ +# Checks that array of years if correctly generated +# """ +# @pytest.mark.unittest +# def test_years_array_is_correctly_generated(self): +# #setup +# year1 = 1980 +# yearN = 1983 +# result_array = [1980,1981,1982,1983] +# result = [] +# waves_era5 = extract_waves_era5.ExtractWavesEra5(0.0,0.0,"path", 1980,2000) +# #call +# result = waves_era5.generate_years_array(year1, yearN) +# #assert +# assert result[0] == result_array[0] +# assert result[1] == result_array[1] +# assert result[2] == result_array[2] +# assert result[3] == result_array[3] \ No newline at end of file Fisheye: Tag 16 refers to a dead (removed) revision in file `trunk/SDToolBox/extract_waves_era5.py'. Fisheye: No comparison available. Pass `N' to diff? Index: trunk/SDToolBox/data_acquisition.py =================================================================== diff -u -r14 -r16 --- trunk/SDToolBox/data_acquisition.py (.../data_acquisition.py) (revision 14) +++ trunk/SDToolBox/data_acquisition.py (.../data_acquisition.py) (revision 16) @@ -1,14 +1,31 @@ #! /usr/bin/env python from datetime import datetime + from netCDF4 import Dataset + +from sklearn.neighbors import BallTree as BallTree + +import xarray as xr import numpy as np -class DataAcquisition: +def get_nearest_neighbor(value, data_array): + """ + search for nearest decimal degree in an array of decimal degrees + and return the index. + np.argmin returns the indices of minium value along an axis. + so subtract value from all values in data_array, take absolute value + and find index of minium. + """ + return (np.abs(data_array - value)).argmin() + +class InputData: + coord_list = [] date_from = None date_to = None + # These parameters are set in the extraction methods min_longitude = None max_longitude = None lon_key = 'lon' @@ -60,32 +77,40 @@ "longitude: {}".format(self.lon_key)) return False + xarray = xr.DataArray(dataset) lat_list = dataset.variables[self.lat_key][:] lon_list = dataset.variables[self.lon_key][:] # validate grid within points, otherwise create bounding box. if len(self.coord_list) == 1: # Only one point, meaning we need to retrieve its closest. - lat, lon = self.coord_list[0] - self.coord_list.clear() - lat_approx = self.__geo_idx(lat, lat_list) - lon_approx = self.__geo_idx(lon, lon_list) - self.coord_list.append((lat_approx, lon_approx)) + lon, lat = self.coord_list[0] + # near_coord = self.__nearest_neighbors(dataset, self.coord_list[0]) + # self.coord_list.clear() + lat_approx = get_nearest_neighbor(lat, lat_list) + lon_approx = get_nearest_neighbor(lon, lon_list) + + self.coord_list.append((lon, lat)) return True return False - @staticmethod - def __geo_idx(dd, dd_array): - """ - search for nearest decimal degree in an array of decimal degrees - and return the index. - np.argmin returns the indices of minium value along an axis. - so subtract dd from all values in dd_array, take absolute value - and find index of minium. - """ - geo_idx = (np.abs(dd_array - dd)).argmin() - return geo_idx + def __nearest_neighbors(self, dataset, coord): + # lon = x, lat = y + x, y = coord + dataset_lon = dataset[self.lon_key][:] + dataset_lat = dataset[self.lat_key][:] + # lon[lon>180]=lon[lon>180]-360 + grid_lon, grid_lat = np.meshgrid(dataset_lon, dataset_lat) + kd = BallTree( + np.column_stack( + (grid_lon.ravel(), grid_lat.ravel())), + leaf_size=10) + _, indx = kd.query(np.column_stack((x, y))) + pos_lon = np.where(dataset_lon == grid_lon.ravel()[indx].squeeze())[0] + pos_lat = np.where(dataset_lat == grid_lat.ravel()[indx].squeeze())[0] + return (pos_lon, pos_lat) + @staticmethod def validate_input_dates(date_from: datetime, date_to: datetime): """Validates the given dates as a range. @@ -106,3 +131,4 @@ print("Input date from needs cannot be later than the date to.") return False return True + \ No newline at end of file Index: trunk/setup.py =================================================================== diff -u -r9 -r16 --- trunk/setup.py (.../setup.py) (revision 9) +++ trunk/setup.py (.../setup.py) (revision 16) @@ -18,5 +18,9 @@ install_requires=[ "netCDF4>=1.2.9", "numpy>=1.12", + "pandas>=0.19.2" + "scipy>=1.0", + "scikit_learn>=0.19.0", + "xarray>=0.13.0" ], ) \ No newline at end of file Index: trunk/SDToolBox/extract_waves.py =================================================================== diff -u --- trunk/SDToolBox/extract_waves.py (revision 0) +++ trunk/SDToolBox/extract_waves.py (revision 16) @@ -0,0 +1,114 @@ +#! /usr/bin/env python +""" + +""" + +# region // imports +import sys +import os + +from SDToolBox import outputmessages as outputmessages +from SDToolBox import data_acquisition +from netCDF4 import Dataset +import numpy as np + +# endregion + +# region // variables + +# endregion + + +class ExtractWaves: + + __lon_key = 'longitude' + __lat_key = 'latitude' + __input_lon = None + __input_lat = None + __input_data = None + + def __init__(self, input_data: data_acquisition.InputData): + """Initialize the waves extraction. + + Arguments: + input_data {data_acquisition.InputData} -- Required. + """ + # verify input_data not none + if not input_data or \ + not input_data.coord_list or \ + len(input_data.coord_list) < 1: + raise IOError('No valid input data.') + + input_data.lon_key = self.__lon_key + input_data.lat_key = self.__lat_key + self.__input_data = input_data + self.__input_lon, self.__input_lat = input_data.coord_list[0] + + def subset_era_5(self, directory_path: str, year_from: int, year_to: int): + """Extracts a collection of netCDF4 subsets based on the input data + set when creating the extract_waves object. + + Arguments: + directory_path {str} -- Location of all the variable diretories. + year_from {int} -- Start of time data to substract. + year_to {int} -- End of time data to substract. + + Returns: + list(Dataset) -- collection of netCDF4 subsets per variable. + """ + variable_dict = { + 'swh': 'Hs', + 'pp1d': 'Tp', + 'mwd': 'MWD', + 'mwp': 'Tm', + } + first_key = 'swh' + # test_data/HS/era5_Global_HS_year + + # longitude should be found as the 'x' in the first coordinate of + self.__input_lon = self.check_for_longitude(self.__input_lon) + years = self.generate_years_array(year_from, year_to) + case_dataset_list = [] + for variable_name in variable_dict: + case_name_value = variable_dict[variable_name] + for year_idx, year in enumerate(years): + print(year, '-', variable_name) + base_file_name = 'era5_Global_{}_{}.nc'.format(case_name_value, year) + case_dir = directory_path + case_name_value + case_file_path = os.path.join(case_dir, base_file_name) + + # If file does not exist simply go to the next one + if not os.path.exists(case_file_path): + print('File {} does not exist or could not be found.'.format(case_file_path)) + continue + + # Open the current Dataset + with Dataset(case_file_path, 'r', format='netCDF4') as case_dataset: + # Get dataset + # Find nearest point (considering we are only selecting a point) + lat_list = case_dataset.variables[self.__lat_key][:] + lon_list = case_dataset.variables[self.__lon_key][:] + corr_lon = data_acquisition.get_nearest_neighbor(self.__input_lon, lon_list) + corr_lat = data_acquisition.get_nearest_neighbor(self.__input_lat, lat_list) + # Get variable subset + variable_subset = case_dataset.variables[variable_name][:corr_lon:corr_lat] + case_dataset_list.append(variable_subset) + + # Previous subsets should already be transformed into xarray + # The collection of datasets should be stored into a unique netcdf (check related issue). + + return case_dataset_list + + @staticmethod + def check_for_longitude(longitude): + if longitude > 180: + return longitude-180 + return longitude + + @staticmethod + def generate_years_array(year_from, year_to): + years = [] + for i in range(year_to - year_from): + years.append(year_from + i) # fills an array of years + years.append(year_to) + return years