# epclust.git: epclust/R/main.R (gitweb blame view; latest commit: "Fix unit tests")
# Roxygen @include directives: the files listed here are collated (loaded)
# before this one. The NULL object only gives roxygen something to attach
# this documentation block to.
#' @include de_serialize.R
#' @include clustering.R
NULL

#' CLAWS: CLustering with wAvelets and Wer distanceS
#'
#' Groups electricity power curves (or any series of similar nature) by applying PAM
#' algorithm in parallel to chunks of size \code{nb_series_per_chunk}. Input series
#' must be sampled on the same time grid, with no missing values.
#'
#' Intermediate binary files (serialized series, wavelets coefficients, synchrones)
#' are written under the directory \code{epclust.bin/}, relative to the current
#' working directory.
#'
#' @param getSeries Access to the (time-)series, which can be of one of the four
#' following types:
#' \itemize{
#' \item matrix: each line contains all the values for one time-series, ordered by time
#' \item connection: any R connection object (e.g. a file) providing lines as described above
#' \item character: name of a (CSV) file providing lines as described above
#' \item function: a custom way to retrieve the curves; it has only one argument:
#' the indices of the series to be retrieved, and must return NULL once the
#' requested indices are past the last series. See examples
#' }
#' @param K1 Number of super-consumers to be found after stage 1 (K1 << N)
#' @param K2 Number of clusters to be found after stage 2 (K2 << K1)
#' @param random TRUE (default) for random chunks repartition
#' @param wf Wavelet transform filter; see ?wavelets::wt.filter. Default: haar
#' @param WER "end" to apply stage 2 after stage 1 has fully iterated, or "mix" to apply stage 2
#' at the end of each task
#' @param ntasks Number of tasks (parallel iterations to obtain K1 medoids); default: 1.
#' Note: ntasks << N (number of series), so that N is "roughly divisible" by ntasks
#' @param ncores_tasks "MPI" number of parallel tasks (1 to disable: sequential tasks)
#' @param ncores_clust "OpenMP" number of parallel clusterings in one task
#' @param nb_series_per_chunk (~Maximum) number of series in each group, inside a task
#' @param min_series_per_chunk Minimum number of series in each group; must lie in
#' [K1, nb_series_per_chunk]
#' @param sep Separator in CSV input file (relevant only if getSeries is a file name)
#' @param nbytes Number of bytes to serialize a floating-point number; 4 or 8
#' @param endian Endianness to use for (de)serialization. Use "little" or "big" for portability
#'
#' @return A matrix of the final medoids curves
#'
#' @examples
#' \dontrun{
#' # Series given as a matrix: one row per series, one column per time step
#' series = matrix(runif(10000*48), nrow=10000)
#' medoids = claws(series, K1=20, K2=5, ntasks=5, nb_series_per_chunk=500)
#'
#' # Same data accessed through a custom function of the indices; it must
#' # return NULL once all series have been read
#' getData = function(indices) {
#'   indices = indices[indices <= nrow(series)]
#'   if (length(indices) == 0) NULL else series[indices,,drop=FALSE]
#' }
#' medoids = claws(getData, K1=20, K2=5, ntasks=5, nb_series_per_chunk=500, WER="mix")
#'
#' # TODO: 3 examples, data.frame / binary file / DB SQLite. With a DB, the array
#' # rank should serve as ID at first retrieval (when computing coeffs), so that
#' # no explicit IDs are needed.
#' # TODO: sampleCurves: see wavBootstrap in package wmtsa
#' }
#' @export
claws = function(getSeries, K1, K2,
	random=TRUE, #randomize series order?
	wf="haar", #stage 1
	WER="end", #stage 2
	ntasks=1, ncores_tasks=1, ncores_clust=4, #control parallelism
	nb_series_per_chunk=50*K1, min_series_per_chunk=5*K1, #chunk size
	sep=",", #ASCII input separator
	nbytes=4, endian=.Platform$endian) #serialization (write,read)
{
	# Check/transform arguments
	if (!is.matrix(getSeries) && !is.function(getSeries) &&
		!inherits(getSeries, "connection") && !is.character(getSeries))
	{
		stop("'getSeries': matrix, function, file or valid connection (no NA)")
	}
	K1 = .toInteger(K1, function(x) x>=2)
	K2 = .toInteger(K2, function(x) x>=2)
	# 'random' is used as a scalar condition below: reject vectors and NA
	if (!is.logical(random) || length(random) != 1 || is.na(random))
		stop("'random': logical")
	tryCatch(
		{ignored <- wavelets::wt.filter(wf)},
		error = function(e) stop("Invalid wavelet filter; see ?wavelets::wt.filter"))
	if (!is.character(WER) || length(WER) != 1 || !(WER %in% c("end","mix")))
		stop("WER takes values in {'end','mix'}")
	ntasks = .toInteger(ntasks, function(x) x>=1)
	ncores_tasks = .toInteger(ncores_tasks, function(x) x>=1)
	ncores_clust = .toInteger(ncores_clust, function(x) x>=1)
	nb_series_per_chunk = .toInteger(nb_series_per_chunk, function(x) x>=K1)
	# Validate the user-supplied (or default 5*K1) minimum chunk size
	min_series_per_chunk = .toInteger(min_series_per_chunk,
		function(x) x>=K1 && x<=nb_series_per_chunk)
	if (!is.character(sep))
		stop("'sep': character")
	nbytes = .toInteger(nbytes, function(x) x==4 || x==8)

	# Serialize series if required, to always use a function
	bin_dir = "epclust.bin/"
	dir.create(bin_dir, showWarnings=FALSE, mode="0755")
	if (!is.function(getSeries))
	{
		series_file = paste(bin_dir,"data",sep="") ; unlink(series_file)
		serialize(getSeries, series_file, nb_series_per_chunk, sep, nbytes, endian)
		getSeries = function(indices) getDataInFile(indices, series_file, nbytes, endian)
	}

	# Serialize all wavelets coefficients onto a file
	coefs_file = paste(bin_dir,"coefs",sep="") ; unlink(coefs_file)
	nb_curves = .seriesToCoefsFile(
		getSeries, coefs_file, wf, nb_series_per_chunk, sep, nbytes, endian)
	getCoefs = function(indices) getDataInFile(indices, coefs_file, nbytes, endian)

	if (nb_curves < min_series_per_chunk)
		stop("Not enough data: less rows than min_series_per_chunk!")

	# Split the (possibly shuffled) series indices into ntasks groups of roughly
	# equal sizes; splitIndices never yields reversed or out-of-range ranges
	indices_all = if (random) sample(nb_curves) else seq_len(nb_curves)
	indices_tasks = lapply(parallel::splitIndices(nb_curves, ntasks),
		function(inds) indices_all[inds])
	if (min(lengths(indices_tasks)) < min_series_per_chunk)
		stop("Too many tasks: less series in one task than min_series_per_chunk!")

	# Stage-2 results of each task are appended here when WER=="mix": the path must
	# be defined (and the file cleared) BEFORE the tasks run, since they write to it
	synchrones_file = paste(bin_dir,"synchrones",sep="") ; unlink(synchrones_file)

	# Cluster coefficients in parallel (by nb_series_per_chunk)
	cl = parallel::makeCluster(ncores_tasks)
	# Medoids indices of every task [if WER=="end"], or empty vector [if WER=="mix"]
	# --> series on file. The cluster is stopped even if a task fails.
	indices = tryCatch(
		unlist( parallel::parLapply(cl, indices_tasks, function(inds) {
			indices_medoids = clusteringTask(inds,getCoefs,K1,nb_series_per_chunk,ncores_clust)
			if (WER=="mix")
			{
				medoids2 = computeClusters2(
					getSeries(indices_medoids), K2, getSeries, nb_series_per_chunk)
				# NOTE(review): several workers may write to synchrones_file concurrently
				serialize(medoids2, synchrones_file, nb_series_per_chunk, sep, nbytes, endian)
				return (integer(0))
			}
			indices_medoids
		}) ),
		finally = parallel::stopCluster(cl))

	getRefSeries = getSeries
	if (WER=="mix")
	{
		indices = seq_len(ntasks*K2)
		#Now series must be retrieved from synchrones_file
		getSeries = function(inds) getDataInFile(inds, synchrones_file, nbytes, endian)
		#Coefs must be re-computed
		unlink(coefs_file)
		.seriesToCoefsFile(getSeries, coefs_file, wf, nb_series_per_chunk, sep, nbytes, endian)
	}

	# Run step2 on resulting indices or series (from file)
	indices_medoids = clusteringTask(
		indices, getCoefs, K1, nb_series_per_chunk, ncores_tasks*ncores_clust)
	computeClusters2(getSeries(indices_medoids),K2,getRefSeries,nb_series_per_chunk)
}

# helper: read all series through getSeries, by consecutive chunks of
# nb_series_per_chunk indices until it returns NULL. Compute their wavelets
# coefficients and serialize them into coefs_file (successive calls presumably
# append to it). Returns the number of series processed.
.seriesToCoefsFile = function(getSeries, coefs_file, wf, nb_series_per_chunk,
	sep, nbytes, endian)
{
	index = 1
	nb_curves = 0
	repeat
	{
		series = getSeries((index-1)+seq_len(nb_series_per_chunk))
		if (is.null(series))
			break
		coefs_chunk = curvesToCoefs(series, wf)
		serialize(coefs_chunk, coefs_file, nb_series_per_chunk, sep, nbytes, endian)
		index = index + nb_series_per_chunk
		nb_curves = nb_curves + nrow(coefs_chunk)
	}
	nb_curves
}

# helper: summarize each series (one per row) by the energies of its discrete
# wavelet transform. Each curve is first spline-interpolated onto 2^D points,
# with D = ceiling(log2(number of columns)). Then, for each of the D detail
# levels, the L2 norm of its coefficients is taken; levels are ordered from
# coarsest to finest (rev() of the W list).
# Returns an n x D matrix (n = number of series), even when n == 0 or D == 1.
curvesToCoefs = function(series, wf)
{
	# apply() used to coerce data frames to matrix implicitly; keep that behavior
	series = as.matrix(series)
	L = ncol(series)
	D = ceiling( log2(L) )
	nb_sample_points = 2^D
	coefs = vapply(seq_len(nrow(series)), function(i) {
		interpolated_curve = spline(seq_len(L), series[i,], n=nb_sample_points)$y
		W = wavelets::dwt(interpolated_curve, filter=wf, D)@W
		rev( vapply( W, function(v) sqrt( sum(v^2) ), numeric(1) ) )
	}, FUN.VALUE=numeric(D))
	# vapply gives a D x n matrix (or a plain vector when D == 1), one column per
	# series: reshape explicitly, then transpose to always get one row per series
	t( matrix(coefs, nrow=D, dimnames=list(rownames(coefs), rownames(series))) )
}

# helper: convert argument 'x' to a single integer (first element only) and
# check it against 'condition', a predicate that must return TRUE for valid values.
# Stops with a message naming the caller's argument if the conversion fails
# (NA, empty, or non-coercible input) or if the condition is not satisfied.
.toInteger <- function(x, condition)
{
	# Capture the caller's expression now: once 'x' is reassigned below,
	# substitute(x) would return its value instead of its name
	arg_name <- paste(deparse(substitute(x)), collapse=" ")
	x <- if (is.integer(x)) {
		x[1]
	} else {
		# as.integer() returns NA (with a warning) on non-numeric strings and errors
		# on e.g. functions; both are turned into NA and reported just below
		tryCatch(
			suppressWarnings(as.integer(x)[1]),
			error = function(e) NA_integer_)
	}
	if (is.na(x))
		stop(paste("Cannot convert argument",arg_name,"to integer"))
	if (!isTRUE(condition(x)))
	{
		stop(paste("Argument",arg_name,"does not verify condition",
			paste(deparse(body(condition)), collapse=" ")))
	}
	x
}