X-Git-Url: https://git.auder.net/%3C?a=blobdiff_plain;f=pkg%2FR%2Fz_runAlgorithm.R;fp=pkg%2FR%2Fz_runAlgorithm.R;h=f1c6ea46713e870ac001f9fc0321663c9183860b;hb=357b8be00388d07c04c10d4bf7f503fd947185d2;hp=0000000000000000000000000000000000000000;hpb=a961f8a15492bf4b18d24bc117358d1f412dd078;p=aggexp.git diff --git a/pkg/R/z_runAlgorithm.R b/pkg/R/z_runAlgorithm.R new file mode 100644 index 0000000..f1c6ea4 --- /dev/null +++ b/pkg/R/z_runAlgorithm.R @@ -0,0 +1,97 @@ +#' @include b_Algorithm.R + +algoNameDictionary = list( + ew = "ExponentialWeights", + kn = "KnearestNeighbors", + ga = "GeneralizedAdditive", + ml = "MLpoly", + rt = "RegressionTree", + rr = "RidgeRegression", + sv = "SVMclassif" +) + +#' @title Simulate real-time predict +#' +#' @description Run the algorithm coded by \code{shortAlgoName} on data specified by the \code{stations} argument. +#' +#' @param shortAlgoName Short name of the algorithm. +#' \itemize{ +#' \item ew : Exponential Weights +#' \item ga : Generalized Additive Model +#' \item kn : K Nearest Neighbors +#' \item ml : MLpoly +#' \item rt : Regression Tree +#' \item rr : Ridge Regression +#' } +#' @param experts Vector of experts to consider (names or indices). Default: all of them. +#' @param stations Vector of stations to consider (names or indices). Default: all of them. +#' @param ... Additional arguments to be passed to the Algorithm object. +#' +#' @return A list with the following slots +#' \itemize{ +#' \item{data : data frame of all forecasts + measures (may contain NAs) + predictions, with date and station indices.} +#' \item{algo : object of class \code{Algorithm} (or sub-class).} +#' \item{experts : character vector of experts for this run.} +#' \item{stations : character vector of stations for this run.} +#' } +#' +#' @export +runAlgorithm = function(shortAlgoName, experts=expertsArray, stations=stationsArray, ...) +{ + #check, sanitize and format provided arguments + if (! shortAlgoName %in% names(algoNameDictionary)) + stop(paste("Typo in short algo name:", shortAlgoName)) + if (!is.character(experts) && !is.numeric(experts)) + stop("Wrong argument type: experts should be character or integer") + if (!is.character(stations) && !is.numeric(stations)) + stop("Wrong argument type: stations should be character or integer") + experts = unique(experts) + stations = unique(stations) + Ka = length(expertsArray) + Sa = length(stationsArray) + if (length(experts) > Ka) + stop("Too many experts specified: at least one of them does not exist") + if (length(stations) > Sa) + stop("Too many stations specified: at least one of them does not exist") + if (is.numeric(experts) && any(experts > Ka)) + stop(paste("Some experts indices are higher than the maximum which is", Ka)) + if (is.numeric(stations) && any(stations > Sa)) + stop(paste("Some stations indices are higher than the maximum which is", Sa)) + if (is.character(experts)) + { + expertsMismatch = (1:Ka)[! experts %in% expertsArray] + if (length(expertsMismatch) > 0) + stop(cat(paste("Typo in experts names:", experts[expertsMismatch]), sep="\n")) + } + if (is.character(stations)) + { + stationsMismatch = (1:Sa)[! stations %in% stationsArray] + if (length(stationsMismatch) > 0) + stop(cat(paste("Typo in stations names:", stations[stationsMismatch]), sep="\n")) + } + if (!is.character(experts)) + experts = expertsArray[experts] + if (!is.character(stations)) + stations = stationsArray[stations] + + #get data == ordered date indices + forecasts + measures + stations indices (would be DB in prod) + oracleData = getData(experts, stations) + + #simulate incremental forecasts acquisition + prediction + get measure + algoData = as.data.frame(matrix(nrow=0, ncol=ncol(oracleData))) + names(algoData) = names(oracleData) + algorithm = new(algoNameDictionary[[shortAlgoName]], data=algoData, ...) + predictions = c() + T = oracleData[nrow(oracleData),"Date"] + for (t in 1:T) + { + #NOTE: bet that subset extract rows in the order they appear + tData = subset(oracleData, subset = (Date==t)) + algorithm$inputNextForecasts(tData[,names(tData) != "Measure"]) + predictions = c(predictions, algorithm$predict_withNA()) + algorithm$inputNextObservations(tData[,"Measure"]) + } + + oracleData = cbind(oracleData, Prediction = predictions) + return (list(data = oracleData, algo = algorithm, experts = experts, stations = stations)) +}