@staticmethod
def cross_validation(train_test_data: Tuple[list, list, list, list]):
    """Fits a linear SVM on the training split of train_test_data (as
    produced by the train_and_validate function) and scores it on the
    test split.

    Note: despite its name this performs a single hold-out validation,
    not a k-fold cross validation.

    Arguments:
        train_test_data {Tuple[list, list, list, list]}
            -- Lon_train, Lon_test, Lat_train, Lat_test. Element 0 is
            used as training features, element 2 as training targets,
            and elements 1 and 3 as the matching test features/targets.

    Raises:
        ValueError: When train_test_data is empty or does not contain
            exactly four elements.

    Returns:
        svm.SVC -- The linear-kernel classifier fitted on the training
        split. Its score (mean accuracy) on the test split is printed
        as a side effect.
    """
    if not train_test_data or len(train_test_data) != 4:
        # ValueError subclasses Exception, so existing callers that
        # catch Exception keep working.
        raise ValueError('Not all arguments given.')
    lon_train, lon_test, lat_train, lat_test = train_test_data
    clf = svm.SVC(kernel='linear', C=1).fit(lon_train, lat_train)
    print(clf.score(lon_test, lat_test))
    return clf


@staticmethod
def time_series_split(
        train_test_data: Tuple[list, list, list, list],
        k_splits: int):
    """Splits the training part of train_test_data into k_splits
    successive, time-ordered train/validation folds.

    Arguments:
        train_test_data {Tuple[list, list, list, list]}
            -- Lon_train, Lon_test, Lat_train, Lat_test. Only the
            training elements (0 and 2) are split.
        k_splits {int} -- Number of folds passed to TimeSeriesSplit
            (which itself rejects values below 2).

    Raises:
        ValueError: When train_test_data does not contain exactly four
            elements or k_splits is not given (None / 0).

    Returns:
        generator -- Lazily yields (train_index, test_index) pairs that
        index into the training samples (Lon_train / Lat_train).
    """
    if (not train_test_data or len(train_test_data) != 4
            or not k_splits):
        raise ValueError('Not all arguments given.')
    lon_train, _, lat_train, _ = train_test_data
    tscv = TimeSeriesSplit(k_splits)
    # Split the training samples themselves. Passing the 4-tuple (as
    # before) made TimeSeriesSplit treat its four arrays as 4 samples.
    return tscv.split(lon_train, lat_train)