| 1 | #include <math.h> |
| 2 | #include <stdlib.h> |
| 3 | |
| 4 | void ew_predict_noNA(double* X, double* Y, int* n_, int* K_, double* alpha_, int* grad_, double* weight) |
| 5 | { |
| 6 | int K = *K_; |
| 7 | int n = *n_; |
| 8 | double alpha = *alpha_; |
| 9 | int grad = *grad_; |
| 10 | |
| 11 | //at least two experts to combine: various inits |
| 12 | double invMaxError = 1. / 50; //TODO: magic number |
| 13 | double logK = log(K); |
| 14 | double initWeight = 1. / K; |
| 15 | for (int i=0; i<K; i++) |
| 16 | weight[i] = initWeight; |
| 17 | double* error = (double*)malloc(K*sizeof(double)); |
| 18 | double* cumError = (double*)calloc(K, sizeof(double)); |
| 19 | |
| 20 | //start main loop |
| 21 | for (int t=0; t<n; t++ < n) |
| 22 | { |
| 23 | if (grad) |
| 24 | { |
| 25 | double hatY = 0.; |
| 26 | for (int i=0; i<K; i++) |
| 27 | hatY += X[t*K+i] * weight[i]; |
| 28 | for (int i=0; i<K; i++) |
| 29 | error[i] = 2. * (hatY - Y[t]) * X[t*K+i]; |
| 30 | } |
| 31 | else |
| 32 | { |
| 33 | for (int i=0; i<K; i++) |
| 34 | { |
| 35 | double delta = X[t*K+i] - Y[t]; |
| 36 | error[i] = delta * delta; |
| 37 | /* if ((X[t*K+i] <= 30 && Y[t] > 30) || (X[t*K+i] > 30 && Y[t] <= 30)) |
| 38 | error[i] = 1.0; |
| 39 | else |
| 40 | error[i] = 0.0; |
| 41 | */ |
| 42 | } |
| 43 | } |
| 44 | for (int i=0; i<K; i++) |
| 45 | cumError[i] += error[i]; |
| 46 | |
| 47 | if (t < n-1 && !grad) |
| 48 | { |
| 49 | //weight update is useless |
| 50 | continue; |
| 51 | } |
| 52 | |
| 53 | //double eta = invMaxError * sqrt(8*logK/(t+1)); //TODO: good formula ? |
| 54 | double eta = invMaxError * 1. / (t+1); //TODO: good formula ? |
| 55 | for (int i=0; i<K; i++) |
| 56 | weight[i] = exp(-eta * cumError[i]); |
| 57 | double sumWeight = 0.0; |
| 58 | for (int i=0; i<K; i++) |
| 59 | sumWeight += weight[i]; |
| 60 | for (int i=0; i<K; i++) |
| 61 | weight[i] /= sumWeight; |
| 62 | //redistribute weights if alpha > 0 (all weights are 0 or more, sum > 0) |
| 63 | for (int i=0; i<K; i++) |
| 64 | weight[i] = (1. - alpha) * weight[i] + alpha/K; |
| 65 | } |
| 66 | |
| 67 | free(error); |
| 68 | free(cumError); |
| 69 | } |