Commit | Line | Data |
---|---|---|
a961f8a1 BA |
1 | #' @include b_Algorithm.R |
2 | ||
# Lookup table: short algorithm code (as accepted by runAlgorithm) -> name of the
# Algorithm sub-class instantiated for it. Built from a named character vector so
# each code sits next to its class name; as.list() keeps the names, giving a named
# list with one character string per entry.
algoNameDictionary = as.list(c(
	ew = "ExponentialWeights",
	kn = "KnearestNeighbors",
	ga = "GeneralizedAdditive",
	ml = "MLpoly",
	rt = "RegressionTree",
	rr = "RidgeRegression",
	sv = "SVMclassif"
))
12 | ||
# Check a user selection of experts or stations and convert it to a character vector of names.
# selection : names (character) or 1-based indices (numeric) supplied by the caller.
# reference : full character vector of valid names (e.g. expertsArray or stationsArray).
# what      : "experts" or "stations"; only used to build error messages.
# Returns the de-duplicated selection as names; stops with an explicit message otherwise.
.sanitizeSelection = function(selection, reference, what)
{
	if (!is.character(selection) && !is.numeric(selection))
		stop(paste("Wrong argument type:", what, "should be character or integer"))
	selection = unique(selection)
	maxCount = length(reference)
	if (length(selection) > maxCount)
		stop(paste("Too many", what, "specified: at least one of them does not exist"))
	if (is.numeric(selection))
	{
		# Reject NA, zero, negative and fractional indices: `[` would otherwise silently
		# drop, exclude or truncate them instead of selecting what the caller asked for.
		if (anyNA(selection) || any(selection < 1 | selection != round(selection)))
			stop(paste("Some", what, "indices are not positive integers"))
		if (any(selection > maxCount))
			stop(paste("Some", what, "indices are higher than the maximum which is", maxCount))
		return (reference[selection])
	}
	# Report every name absent from `reference`, one per line of the error message.
	unknown = selection[!(selection %in% reference)]
	if (length(unknown) > 0)
		stop(paste("Typo in", what, "names:", unknown, collapse="\n"))
	selection
}

#' @title Simulate real-time prediction
#'
#' @description Run the algorithm coded by \code{shortAlgoName} on the data selected by the
#' \code{experts} and \code{stations} arguments. The data are fed to the algorithm one date at a time.
#' For each date, the forecasts are given first, then a prediction is requested, and finally the
#' measures are revealed.
#'
#' @param shortAlgoName Short name of the algorithm.
#' \itemize{
#' \item ew : Exponential Weights
#' \item ga : Generalized Additive Model
#' \item kn : K Nearest Neighbors
#' \item ml : MLpoly
#' \item rt : Regression Tree
#' \item rr : Ridge Regression
#' \item sv : SVM classification
#' }
#' @param experts Vector of experts to consider (names or positive indices into \code{expertsArray}).
#' Default: all of them.
#' @param stations Vector of stations to consider (names or positive indices into \code{stationsArray}).
#' Default: all of them.
#' @param ... Additional arguments passed to \code{new()} when building the Algorithm object.
#'
#' @return A list with the following slots
#' \itemize{
#' \item{data : data frame of all forecasts + measures (may contain NAs) + predictions, with date and station indices.}
#' \item{algo : object of class \code{Algorithm} (or sub-class).}
#' \item{experts : character vector of experts for this run.}
#' \item{stations : character vector of stations for this run.}
#' }
#'
#' @details An error is raised when:
#' \itemize{
#' \item{\code{shortAlgoName} is unknown.}
#' \item{\code{experts} or \code{stations} has a wrong type.}
#' \item{\code{experts} or \code{stations} contains unknown names.}
#' \item{\code{experts} or \code{stations} contains indices that are not integers between 1 and the
#' number of available experts (resp. stations).}
#' }
#'
#' @export
runAlgorithm = function(shortAlgoName, experts=expertsArray, stations=stationsArray, ...)
{
	#check, sanitize and format provided arguments
	if (length(shortAlgoName) != 1 || ! shortAlgoName %in% names(algoNameDictionary))
		stop(paste("Typo in short algo name:", shortAlgoName))
	experts = .sanitizeSelection(experts, expertsArray, "experts")
	stations = .sanitizeSelection(stations, stationsArray, "stations")

	#get data == ordered date indices + forecasts + measures + stations indices (would be DB in prod)
	oracleData = getData(experts, stations)

	#simulate incremental forecasts acquisition + prediction + get measure
	algoData = as.data.frame(matrix(nrow=0, ncol=ncol(oracleData)))
	names(algoData) = names(oracleData)
	algorithm = new(algoNameDictionary[[shortAlgoName]], data=algoData, ...)

	# Rows are assumed sorted by (integer) date index, so the last row holds the last date.
	# An empty data set yields 0 dates instead of failing.
	lastDate = if (nrow(oracleData) > 0) oracleData[nrow(oracleData), "Date"] else 0
	forecastColumns = names(oracleData) != "Measure"
	# One slot per date, concatenated once after the loop (avoids growing a vector).
	predictionsByDate = vector("list", lastDate)
	for (t in seq_len(lastDate))
	{
		# which() keeps rows in their original order (and skips NA dates, like subset()),
		# so the concatenated predictions line up with the rows of oracleData.
		tData = oracleData[which(oracleData$Date == t), , drop=FALSE]
		algorithm$inputNextForecasts(tData[, forecastColumns, drop=FALSE])
		# Single-bracket assignment keeps a NULL prediction from deleting the slot.
		predictionsByDate[t] = list(algorithm$predict_withNA())
		algorithm$inputNextObservations(tData[, "Measure"])
	}

	predictions = unlist(predictionsByDate)
	if (is.null(predictions))
		predictions = rep(NA_real_, nrow(oracleData)) # no prediction at all: keep the column, filled with NA
	oracleData = cbind(oracleData, Prediction = predictions)
	return (list(data = oracleData, algo = algorithm, experts = experts, stations = stations))
}