de_serialize works. Variables names improved. Code beautified. TODO: clustering tests
[epclust.git] / epclust / R / main.R
CommitLineData
56857861
BA
1#' @include utils.R
2#' @include clustering.R
3NULL
4
5#' Cluster power curves with PAM in parallel CLAWS: CLustering with wAvelets and Wer distanceS
7f0781b7 6#'
56857861 7#' Groups electricity power curves (or any series of similar nature) by applying PAM
cea14f3a 8#' algorithm in parallel to chunks of size \code{nb_series_per_chunk}
7f0781b7
BA
9#'
10#' @param getSeries Access to the data, which can be of one of the three following types:
11#' \itemize{
12#' \item data.frame: each line contains its ID in the first cell, and all values after
13#' \item connection: any R connection object (e.g. a file) providing lines as described above
c33af7e4
BA
14#' \item function: a custom way to retrieve the curves; it has two arguments: the ranks to be
15#' retrieved, and the IDs - at least one of them must be present (priority: ranks).
7f0781b7 16#' }
1c6f223e
BA
17#' @param K1 Number of super-consumers to be found after stage 1 (K1 << N)
18#' @param K2 Number of clusters to be found after stage 2 (K2 << K1)
19#' @param ntasks Number of tasks (parallel iterations to obtain K1 medoids); default: 1.
20#' Note: ntasks << N (number of series), so that N is "roughly divisible" by ntasks
21#' @param nb_series_per_chunk (Maximum) number of series in each group, inside a task
cea14f3a 22#' @param min_series_per_chunk Minimum number of series in each group
3465b246 23#' @param wf Wavelet transform filter; see ?wt.filter. Default: haar
7f0781b7
BA
24#' @param WER "end" to apply stage 2 after stage 1 has iterated and finished, or "mix"
25#' to apply it after every stage 1
5c652979
BA
26#' @param ncores_tasks "MPI" number of parallel tasks (1 to disable: sequential tasks)
27#' @param ncores_clust "OpenMP" number of parallel clusterings in one task
28#' @param random Randomize chunks repartition
0e2dce80 29#' @param ... Other arguments to be passed to \code{data} function
7f0781b7
BA
30#'
31#' @return A data.frame of the final medoids curves (identifiers + values)
1c6f223e
BA
32#'
33#' @examples
34#' getData = function(start, n) {
35#' con = dbConnect(drv = RSQLite::SQLite(), dbname = "mydata.sqlite")
36#' df = dbGetQuery(con, paste(
37#' "SELECT * FROM times_values GROUP BY id OFFSET ",start,
38#' "LIMIT ", n, " ORDER BY date", sep=""))
39#' return (df)
40#' }
e205f218 41#' #####TODO: if DB, array rank --> ID at first retrieval, when computing coeffs; so:: NO use of IDs !
5c652979
BA
42#' #TODO: 3 examples, data.frame / binary file / DB sqLite
43#' + sampleCurves : wavBootstrap de package wmtsa
1c6f223e
BA
44#' cl = claws(getData, K1=200, K2=15, ntasks=1000, nb_series_per_chunk=5000, WER="mix")
45#' @export
56857861
BA
claws <- function(getSeries, K1, K2,
  random=TRUE, #randomize series order?
  wf="haar", #stage 1
  WER="end", #stage 2
  ntasks=1, ncores_tasks=1, ncores_clust=4, #control parallelism
  nb_series_per_chunk=50*K1, min_series_per_chunk=5*K1, #chunk size
  sep=",", #ASCII input separator
  nbytes=4, endian=.Platform$endian) #serialization (write,read)
{
  # Overview of what this function does:
  #  1. validate / coerce the arguments;
  #  2. if 'getSeries' is not a function, serialize the series to a binary file
  #     and replace 'getSeries' by a reader of that file;
  #  3. compute wavelet coefficients chunk by chunk and serialize them to a file;
  #  4. split the series indices into 'ntasks' groups and run clusteringTask()
  #     on each group in parallel;
  #  5. run a final clusteringTask() + computeClusters2() on the outcome.
  # The value of the final computeClusters2() call is returned.

  # Check/transform arguments
  # (each type test is a separate term: a connection or a file name is accepted)
  if (!is.matrix(getSeries) && !is.function(getSeries) &&
    !inherits(getSeries, "connection") && !is.character(getSeries))
  {
    stop("'getSeries': matrix, function, file or valid connection (no NA)")
  }
  K1 <- .toInteger(K1, function(x) x>=2)
  K2 <- .toInteger(K2, function(x) x>=2)
  if (!is.logical(random) || length(random) != 1 || is.na(random))
    stop("'random': logical")
  tryCatch(
    wavelets::wt.filter(wf),
    error = function(e) stop("Invalid wavelet filter; see ?wavelets::wt.filter"))
  if (!is.character(WER) || length(WER) != 1 || !(WER %in% c("end","mix")))
    stop("WER takes values in {'end','mix'}")
  ntasks <- .toInteger(ntasks, function(x) x>=1)
  ncores_tasks <- .toInteger(ncores_tasks, function(x) x>=1)
  ncores_clust <- .toInteger(ncores_clust, function(x) x>=1)
  nb_series_per_chunk <- .toInteger(nb_series_per_chunk, function(x) x>=K1)
  # Validate the user-supplied minimum itself (not K1)
  min_series_per_chunk <- .toInteger(min_series_per_chunk,
    function(x) x>=K1 && x<=nb_series_per_chunk)
  if (!is.character(sep))
    stop("'sep': character")
  nbytes <- .toInteger(nbytes, function(x) x==4 || x==8)

  # All intermediate binary files live in this directory
  bin_dir <- "epclust.bin/"
  dir.create(bin_dir, showWarnings=FALSE, mode="0755")
  coefs_file <- paste0(bin_dir, "coefs")
  # Defined (and reset) *before* the parallel stage: the workers write to this file
  # when WER=="mix", so it must exist in the closure they receive and must not be
  # deleted after they have run.
  synchrones_file <- paste0(bin_dir, "synchrones")
  unlink(synchrones_file)

  # Serialize series if required, to always use a function
  if (!is.function(getSeries))
  {
    series_file <- paste0(bin_dir, "data")
    unlink(series_file)
    serialize(getSeries, series_file, nb_series_per_chunk, sep, nbytes, endian)
    getSeries <- function(indices) getDataInFile(indices, series_file, nbytes, endian)
  }

  # Helper: (re)write coefs_file with the coefficients of every series provided by
  # 'getter', reading nb_series_per_chunk series at a time until 'getter' returns
  # NULL. Returns the accumulated nrow() of the coefficient chunks.
  serializeCoefs <- function(getter)
  {
    unlink(coefs_file)
    index <- 1
    nb_rows <- 0
    repeat
    {
      series <- getter((index-1) + seq_len(nb_series_per_chunk))
      if (is.null(series))
        break
      coeffs_chunk <- curvesToCoeffs(series, wf)
      serialize(coeffs_chunk, coefs_file, nb_series_per_chunk, sep, nbytes, endian)
      index <- index + nb_series_per_chunk
      # NOTE(review): curvesToCoeffs() builds its result with apply(series, 1, ...),
      # which yields one *column* per series; nrow() may therefore count levels
      # rather than curves -- confirm the orientation expected by serialize().
      nb_rows <- nb_rows + nrow(coeffs_chunk)
    }
    nb_rows
  }

  # Serialize all wavelets coefficients (+ IDs) onto a file
  nb_curves <- serializeCoefs(getSeries)
  getCoefs <- function(indices) getDataInFile(indices, coefs_file, nbytes, endian)

  if (nb_curves < min_series_per_chunk)
    stop("Not enough data: less rows than min_series_per_chunk!")
  # Integer division: the first ntasks-1 tasks get exactly this many series and the
  # last one absorbs the remainder, so no task range can start beyond nb_curves.
  nb_series_per_task <- nb_curves %/% ntasks
  if (nb_series_per_task < min_series_per_chunk)
    stop("Too many tasks: less series in one task than min_series_per_chunk!")

  # Cluster coefficients in parallel (by nb_series_per_chunk)
  indices_all <- if (random) sample(nb_curves) else seq_len(nb_curves)
  indices_tasks <- lapply(seq_len(ntasks), function(i) {
    lower <- (i-1)*nb_series_per_task + 1
    upper <- if (i < ntasks) i*nb_series_per_task else nb_curves
    indices_all[lower:upper]
  })
  cl <- parallel::makeCluster(ncores_tasks)
  # Medoid indices of every task [if WER=="end"], or empty vector [if WER=="mix"]
  # (in that case the stage-2 output of each task is written to synchrones_file).
  # The cluster is stopped even if a task fails.
  indices <- tryCatch(
    unlist( parallel::parLapply(cl, indices_tasks, function(inds) {
      indices_medoids <- clusteringTask(inds,getCoefs,K1,nb_series_per_chunk,ncores_clust)
      if (WER=="mix")
      {
        medoids2 <- computeClusters2(
          getSeries(indices_medoids), K2, getSeries, nb_series_per_chunk)
        serialize(medoids2, synchrones_file, nb_series_per_chunk, sep, nbytes, endian)
        return (vector("integer",0))
      }
      indices_medoids
    }) ),
    finally = parallel::stopCluster(cl))

  # Keep a handle on the original series accessor for the final stage
  getSeriesForSynchrones <- getSeries
  if (WER=="mix")
  {
    indices <- seq_len(ntasks*K2)
    #Now series must be retrieved from synchrones_file
    getSeries <- function(inds) getDataInFile(inds, synchrones_file, nbytes, endian)
    #Coefs must be re-computed (getCoefs keeps reading coefs_file)
    serializeCoefs(getSeries)
  }

  # Run step2 on resulting indices or series (from file)
  indices_medoids <- clusteringTask(
    indices, getCoefs, K1, nb_series_per_chunk, ncores_tasks*ncores_clust)
  computeClusters2(getSeries(indices_medoids),K2,getSeriesForSynchrones,nb_series_per_chunk)
}
158
159# helper
curvesToCoeffs <- function(series, wf)
{
  # Summarise every row of 'series' by the energy of its discrete wavelet
  # transform at each level, using the wavelet filter 'wf'.
  #
  # Each row is first resampled by spline interpolation onto 2^depth points,
  # where depth = ceiling(log2(number of columns)); the DWT is then computed
  # down to 'depth' levels.
  n_points <- ncol(series)
  depth <- ceiling(log2(n_points))
  grid_size <- 2^depth

  # For one curve: Euclidean norm of the detail coefficients of every level,
  # returned in reverse level order.
  energyPerLevel <- function(curve)
  {
    resampled <- spline(1:n_points, curve, n=grid_size)$y
    details <- wavelets::dwt(resampled, filter=wf, n.levels=depth)@W
    norms <- sapply(details, function(d) sqrt(sum(d^2)))
    rev(norms)
  }

  # apply() over rows: when several levels are returned per curve, the result
  # is a matrix holding one column per input series.
  apply(series, 1, energyPerLevel)
}
171
172# helper
.toInteger <- function(x, condition)
{
  # Coerce 'x' to an integer and check it against 'condition'.
  #
  # x:         value to convert; if it is not already of integer type, only its
  #            first element is kept after conversion.
  # condition: function of one argument which must return TRUE for the
  #            converted value.
  # Returns the integer value; stops with an explicit message if the conversion
  # fails (error or NA) or if the condition is not satisfied.

  # Capture the caller's expression now: once 'x' is reassigned below,
  # substitute(x) would yield the converted value instead of the argument name.
  arg_name <- paste(deparse(substitute(x)), collapse="")
  if (!is.integer(x))
  {
    # as.integer() signals a mere warning (and yields NA) on unparsable input,
    # and an error on non-coercible objects: map both cases to NA.
    x <- tryCatch(
      suppressWarnings(as.integer(x)[1]),
      error = function(e) NA_integer_)
  }
  if (length(x) == 0 || anyNA(x))
    stop(paste("Cannot convert argument", arg_name, "to integer"))
  # isTRUE() also rejects NA or non-scalar results of 'condition'
  if (!isTRUE(condition(x)))
  {
    stop(paste("Argument", arg_name, "does not verify condition",
      paste(deparse(body(condition)), collapse=" ")))
  }
  x
}