// EMGLLF: iterative EM routine for a penalized (Lasso) mixture model; see the
// per-parameter comments in the signature below.
//
// What the visible text establishes:
//  - phi, rho and pi are initialized as copies of phiInit (p*m*k values),
//    rhoInit (m*m*k values) and piInit (k values); LLF is zeroed over maxi entries.
//  - Work buffers are obtained with malloc / gsl_matrix_alloc / gsl_permutation_alloc
//    and are all released at the end of the function.
//  - The main loop runs while ite < mini, or while ite < maxi and either
//    dist >= tau or dist2 >= sqrt(tau). Each pass first saves phi, rho, pi into
//    Phi, Rho, Pi.
//  - phi entries are updated from S and Gram2 using the threshold
//    n*lambda*pow(pi[r],gamma).
//  - In the E step, an m x m GSL matrix is filled from rho, LU-decomposed, and its
//    determinant feeds Gam[i,r] = pi[r] * det * exp(-0.5*dotProducts[r] + shift).
//  - dist2 is the largest of Dist1, Dist2 and Dist3.
//
// NOTE(review): this text looks truncated wherever a '<' ... '>' span occurred
// (the second #include has no header name; several for-loops have lost their
// bounds and bodies). Restore the file from the original source before editing logic.
// NOTE(review): in the visible text no malloc result is checked before use -- confirm
// whether an out-of-memory policy is handled elsewhere.
#include "EMGLLF.h" #include // TODO: don't recompute indexes every time...... void EMGLLF( // IN parameters const double* phiInit, // initial renormalized mean parameter const double* rhoInit, // initial renormalized variance parameter const double* piInit, // initial proportions parameter const double* gamInit, // initial posterior probabilities of each sample int mini, // minimum number of iterations of the EM algorithm int maxi, // maximum number of iterations of the EM algorithm double gamma, // gamma value: power of the proportions in the penalty for an adaptive Lasso double lambda, // value of the Lasso regularization parameter const double* X, // regressors const double* Y, // response double tau, // threshold for accepting convergence // OUT parameters (all pointers, to be modified) double* phi, // renormalized mean parameter, computed by EM double* rho, // renormalized variance parameter, computed by EM double* pi, // renormalized proportions parameter, computed by EM double* LLF, // log-likelihood associated with this sample, for the estimated parameter values double* S, // additional size parameters int n, // number of samples int p, // number of covariates int m, // size of Y (multivariate) int k) // number of components in the mixture { //Initialize outputs copyArray(phiInit, phi, p*m*k); copyArray(rhoInit, rho, m*m*k); copyArray(piInit, pi, k); zeroArray(LLF, maxi); //S is already allocated, and doesn't need to be 'zeroed' //Other local variables //NOTE: variables order is always [maxi],n,p,m,k double* gam = (double*)malloc(n*k*sizeof(double)); copyArray(gamInit, gam, n*k); double* b = (double*)malloc(k*sizeof(double)); double* Phi = (double*)malloc(p*m*k*sizeof(double)); double* Rho = (double*)malloc(m*m*k*sizeof(double)); double* Pi = (double*)malloc(k*sizeof(double)); double* gam2 = (double*)malloc(k*sizeof(double)); double* pi2 = (double*)malloc(k*sizeof(double)); 
double* Gram2 = (double*)malloc(p*p*k*sizeof(double)); double* ps = (double*)malloc(m*k*sizeof(double)); double* nY2 = (double*)malloc(m*k*sizeof(double)); double* ps1 = (double*)malloc(n*m*k*sizeof(double)); double* ps2 = (double*)malloc(p*m*k*sizeof(double)); double* nY21 = (double*)malloc(n*m*k*sizeof(double)); double* Gam = (double*)malloc(n*k*sizeof(double)); double* X2 = (double*)malloc(n*p*k*sizeof(double)); double* Y2 = (double*)malloc(n*m*k*sizeof(double)); gsl_matrix* matrix = gsl_matrix_alloc(m, m); gsl_permutation* permutation = gsl_permutation_alloc(m); double* YiRhoR = (double*)malloc(m*sizeof(double)); double* XiPhiR = (double*)malloc(m*sizeof(double)); double dist = 0.; double dist2 = 0.; int ite = 0; double EPS = 1e-15; double* dotProducts = (double*)malloc(k*sizeof(double)); while (ite < mini || (ite < maxi && (dist >= tau || dist2 >= sqrt(tau)))) { copyArray(phi, Phi, p*m*k); copyArray(rho, Rho, m*m*k); copyArray(pi, Pi, k); // Computations related to Y and X for (int r=0; r double dotProduct = 0.0; for (int u=0; u n*lambda*pow(pi[r],gamma)) phi[ai(j,mm,r,p,m,k)] = (n*lambda*pow(pi[r],gamma) - S[ai(j,mm,r,p,m,k)]) / Gram2[ai(j,j,r,p,p,k)]; else phi[ai(j,mm,r,p,m,k)] = -(n*lambda*pow(pi[r],gamma) + S[ai(j,mm,r,p,m,k)]) / Gram2[ai(j,j,r,p,p,k)]; } } } //////////// // E step // //////////// int signum; double sumLogLLF2 = 0.0; for (int i=0; i dotProducts[r] = 0.0; for (int u=0; udata[u*m+v] = rho[ai(u,v,r,m,m,k)]; } gsl_linalg_LU_decomp(matrix, permutation, &signum); double detRhoR = gsl_linalg_LU_det(matrix, signum); Gam[mi(i,r,n,k)] = pi[r] * detRhoR * exp(-0.5*dotProducts[r] + shift); sumLLF1 += Gam[mi(i,r,n,k)] / pow(2*M_PI,m/2.0); sumGamI += Gam[mi(i,r,n,k)]; } sumLogLLF2 += log(sumLLF1); for (int r=0; r EPS ? 
Gam[mi(i,r,n,k)] / sumGamI : 0.0; } } //sum(pen(ite,:)) double sumPen = 0.0; for (int r=0; r Dist1) Dist1 = tmpDist; } } } //Dist2=max(max((abs(rho-Rho))./(1+abs(rho)))); double Dist2 = 0.0; for (int u=0; u Dist2) Dist2 = tmpDist; } } } //Dist3=max(max((abs(pi-Pi))./(1+abs(Pi)))); double Dist3 = 0.0; for (int u=0; u Dist3) Dist3 = tmpDist; } } //dist2=max([max(Dist1),max(Dist2),max(Dist3)]); dist2 = Dist1; if (Dist2 > dist2) dist2 = Dist2; if (Dist3 > dist2) dist2 = Dist3; ite++; } //free memory free(b); free(gam); free(Gam); free(Phi); free(Rho); free(Pi); free(ps); free(nY2); free(ps1); free(nY21); free(Gram2); free(ps2); gsl_matrix_free(matrix); gsl_permutation_free(permutation); free(XiPhiR); free(YiRhoR); free(gam2); free(pi2); free(X2); free(Y2); free(dotProducts); }