rename pkg --> aggexp
[aggexp.git] / aggexp / R / z_runAlgorithm.R
diff --git a/aggexp/R/z_runAlgorithm.R b/aggexp/R/z_runAlgorithm.R
new file mode 100644 (file)
index 0000000..f1c6ea4
--- /dev/null
@@ -0,0 +1,97 @@
+#' @include b_Algorithm.R
+
+algoNameDictionary = list(
+       ew = "ExponentialWeights",
+       kn = "KnearestNeighbors",
+       ga = "GeneralizedAdditive",
+       ml = "MLpoly",
+       rt = "RegressionTree",
+       rr = "RidgeRegression",
+       sv = "SVMclassif"
+)
+
+#' @title Simulate real-time predict
+#'
+#' @description Run the algorithm coded by \code{shortAlgoName} on data specified by the \code{stations} argument.
+#'
+#' @param shortAlgoName Short name of the algorithm.
+#' \itemize{
+#'   \item ew : Exponential Weights
+#'   \item ga : Generalized Additive Model
+#'   \item kn : K Nearest Neighbors
+#'   \item ml : MLpoly
+#'   \item rt : Regression Tree
+#'   \item rr : Ridge Regression
+#' }
+#' @param experts Vector of experts to consider (names or indices). Default: all of them.
+#' @param stations Vector of stations to consider (names or indices). Default: all of them.
+#' @param ... Additional arguments to be passed to the Algorithm object.
+#'
+#' @return A list with the following slots
+#' \itemize{
+#'   \item{data : data frame of all forecasts + measures (may contain NAs) + predictions, with date and station indices.}
+#'   \item{algo : object of class \code{Algorithm} (or sub-class).}
+#'   \item{experts : character vector of experts for this run.}
+#'   \item{stations : character vector of stations for this run.}
+#' }
+#'
+#' @export
+runAlgorithm = function(shortAlgoName, experts=expertsArray, stations=stationsArray, ...)
+{
+       #check, sanitize and format provided arguments
+       if (! shortAlgoName %in% names(algoNameDictionary))
+               stop(paste("Typo in short algo name:", shortAlgoName))
+       if (!is.character(experts) && !is.numeric(experts))
+               stop("Wrong argument type: experts should be character or integer")
+       if (!is.character(stations) && !is.numeric(stations))
+               stop("Wrong argument type: stations should be character or integer")
+       experts = unique(experts)
+       stations = unique(stations)
+       Ka = length(expertsArray)
+       Sa = length(stationsArray)
+       if (length(experts) > Ka)
+               stop("Too many experts specified: at least one of them does not exist")
+       if (length(stations) > Sa)
+               stop("Too many stations specified: at least one of them does not exist")
+       if (is.numeric(experts) && any(experts > Ka))
+               stop(paste("Some experts indices are higher than the maximum which is", Ka))
+       if (is.numeric(stations) && any(stations > Sa))
+               stop(paste("Some stations indices are higher than the maximum which is", Sa))
+       if (is.character(experts))
+       {
+               expertsMismatch = (1:Ka)[! experts %in% expertsArray]
+               if (length(expertsMismatch) > 0)
+                       stop(cat(paste("Typo in experts names:", experts[expertsMismatch]), sep="\n"))
+       }
+       if (is.character(stations))
+       {
+               stationsMismatch = (1:Sa)[! stations %in% stationsArray]
+               if (length(stationsMismatch) > 0)
+                       stop(cat(paste("Typo in stations names:", stations[stationsMismatch]), sep="\n"))
+       }
+       if (!is.character(experts))
+               experts = expertsArray[experts]
+       if (!is.character(stations))
+               stations = stationsArray[stations]
+
+       #get data == ordered date indices + forecasts + measures + stations indices (would be DB in prod)
+       oracleData = getData(experts, stations)
+
+       #simulate incremental forecasts acquisition + prediction + get measure
+       algoData = as.data.frame(matrix(nrow=0, ncol=ncol(oracleData)))
+       names(algoData) = names(oracleData)
+       algorithm = new(algoNameDictionary[[shortAlgoName]], data=algoData, ...)
+       predictions = c()
+       T = oracleData[nrow(oracleData),"Date"]
+       for (t in 1:T)
+       {
+               #NOTE: bet that subset extract rows in the order they appear
+               tData = subset(oracleData, subset = (Date==t))
+               algorithm$inputNextForecasts(tData[,names(tData) != "Measure"])
+               predictions = c(predictions, algorithm$predict_withNA())
+               algorithm$inputNextObservations(tData[,"Measure"])
+       }
+
+       oracleData = cbind(oracleData, Prediction = predictions)
+       return (list(data = oracleData, algo = algorithm, experts = experts, stations = stations))
+}