#include "EMGLLF.h" #include // TODO: comment on EMGLLF purpose void EMGLLF( // IN parameters const Real* phiInit, // parametre initial de moyenne renormalisé const Real* rhoInit, // parametre initial de variance renormalisé const Real* piInit, // parametre initial des proportions const Real* gamInit, // paramètre initial des probabilités a posteriori de chaque échantillon Int mini, // nombre minimal d'itérations dans l'algorithme EM Int maxi, // nombre maximal d'itérations dans l'algorithme EM Real gamma, // valeur de gamma : puissance des proportions dans la pénalisation pour un Lasso adaptatif Real lambda, // valeur du paramètre de régularisation du Lasso const Real* X, // régresseurs const Real* Y, // réponse Real tau, // seuil pour accepter la convergence // OUT parameters (all pointers, to be modified) Real* phi, // parametre de moyenne renormalisé, calculé par l'EM Real* rho, // parametre de variance renormalisé, calculé par l'EM Real* pi, // parametre des proportions renormalisé, calculé par l'EM Real* LLF, // log vraisemblance associé à cet échantillon, pour les valeurs estimées des paramètres Real* S, // additional size parameters mwSize n, // nombre d'echantillons mwSize p, // nombre de covariables mwSize m, // taille de Y (multivarié) mwSize k) // nombre de composantes dans le mélange { //Initialize outputs copyArray(phiInit, phi, p*m*k); copyArray(rhoInit, rho, m*m*k); copyArray(piInit, pi, k); zeroArray(LLF, maxi); //S is already allocated, and doesn't need to be 'zeroed' //Other local variables //NOTE: variables order is always [maxi],n,p,m,k Real* gam = (Real*)malloc(n*k*sizeof(Real)); copyArray(gamInit, gam, n*k); Real* b = (Real*)malloc(k*sizeof(Real)); Real* Phi = (Real*)malloc(p*m*k*sizeof(Real)); Real* Rho = (Real*)malloc(m*m*k*sizeof(Real)); Real* Pi = (Real*)malloc(k*sizeof(Real)); Real* gam2 = (Real*)malloc(k*sizeof(Real)); Real* pi2 = (Real*)malloc(k*sizeof(Real)); Real* Gram2 = (Real*)malloc(p*p*k*sizeof(Real)); Real* ps = (Real*)malloc(m*k*sizeof(Real)); Real* nY2 = (Real*)malloc(m*k*sizeof(Real)); Real* ps1 = (Real*)malloc(n*m*k*sizeof(Real)); Real* ps2 = (Real*)malloc(p*m*k*sizeof(Real)); Real* nY21 = (Real*)malloc(n*m*k*sizeof(Real)); Real* Gam = (Real*)malloc(n*k*sizeof(Real)); Real* X2 = (Real*)malloc(n*p*k*sizeof(Real)); Real* Y2 = (Real*)malloc(n*m*k*sizeof(Real)); gsl_matrix* matrix = gsl_matrix_alloc(m, m); gsl_permutation* permutation = gsl_permutation_alloc(m); Real* YiRhoR = (Real*)malloc(m*sizeof(Real)); Real* XiPhiR = (Real*)malloc(m*sizeof(Real)); Real dist = 0.0; Real dist2 = 0.0; Int ite = 0; Real EPS = 1e-15; Real* dotProducts = (Real*)malloc(k*sizeof(Real)); while (ite < mini || (ite < maxi && (dist >= tau || dist2 >= sqrt(tau)))) { copyArray(phi, Phi, p*m*k); copyArray(rho, Rho, m*m*k); copyArray(pi, Pi, k); // Calculs associes a Y et X for (mwSize r=0; r Real dotProduct = 0.0; for (mwSize u=0; u n*lambda*pow(pi[r],gamma)) phi[j*m*k+mm*k+r] = (n*lambda*pow(pi[r],gamma) - S[j*m*k+mm*k+r]) / Gram2[j*p*k+j*k+r]; else phi[j*m*k+mm*k+r] = -(n*lambda*pow(pi[r],gamma) + S[j*m*k+mm*k+r]) / Gram2[j*p*k+j*k+r]; } } } ///////////// // Etape E // ///////////// int signum; Real sumLogLLF2 = 0.0; for (mwSize i=0; i dotProducts[r] = 0.0; for (mwSize u=0; udata[u*m+v] = rho[u*m*k+v*k+r]; } gsl_linalg_LU_decomp(matrix, permutation, &signum); Real detRhoR = gsl_linalg_LU_det(matrix, signum); Gam[i*k+r] = pi[r] * detRhoR * exp(-0.5*dotProducts[r] + shift); sumLLF1 += Gam[i*k+r] / pow(2*M_PI,m/2.0); sumGamI += Gam[i*k+r]; } sumLogLLF2 += log(sumLLF1); for (mwSize r=0; r EPS ? Gam[i*k+r] / sumGamI : 0.0; } } //sum(pen(ite,:)) Real sumPen = 0.0; for (mwSize r=0; r Dist1) Dist1 = tmpDist; } } } //Dist2=max(max((abs(rho-Rho))./(1+abs(rho)))); Real Dist2 = 0.0; for (mwSize u=0; u Dist2) Dist2 = tmpDist; } } } //Dist3=max(max((abs(pi-Pi))./(1+abs(Pi)))); Real Dist3 = 0.0; for (mwSize u=0; u Dist3) Dist3 = tmpDist; } } //dist2=max([max(Dist1),max(Dist2),max(Dist3)]); dist2 = Dist1; if (Dist2 > dist2) dist2 = Dist2; if (Dist3 > dist2) dist2 = Dist3; ite++; } //free memory free(b); free(gam); free(Gam); free(Phi); free(Rho); free(Pi); free(ps); free(nY2); free(ps1); free(nY21); free(Gram2); free(ps2); gsl_matrix_free(matrix); gsl_permutation_free(permutation); free(XiPhiR); free(YiRhoR); free(gam2); free(pi2); free(X2); free(Y2); free(dotProducts); }