merge with remote
[valse.git] / pkg / src / sources / EMGLLF.c
index d2f5a8e..b77f24a 100644 (file)
@@ -16,7 +16,7 @@ void EMGLLF_core(
        Real lambda, // valeur du paramètre de régularisation du Lasso
        const Real* X, // régresseurs
        const Real* Y, // réponse
-       Real tau, // seuil pour accepter la convergence
+       Real eps, // seuil pour accepter la convergence
        // OUT parameters (all pointers, to be modified)
        Real* phi, // parametre de moyenne renormalisé, calculé par l'EM
        Real* rho, // parametre de variance renormalisé, calculé par l'EM
@@ -39,6 +39,7 @@ void EMGLLF_core(
 
        //Other local variables: same as in R
        Real* gam = (Real*)malloc(n*k*sizeof(Real));
+       Real* logGam = (Real*)malloc(k*sizeof(Real));
        copyArray(gamInit, gam, n*k);
        Real* Gram2 = (Real*)malloc(p*p*k*sizeof(Real));
        Real* ps2 = (Real*)malloc(p*m*k*sizeof(Real));
@@ -47,7 +48,6 @@ void EMGLLF_core(
        Real* Y2 = (Real*)malloc(n*m*k*sizeof(Real));
        *llh = -INFINITY;
        Real* pi2 = (Real*)malloc(k*sizeof(Real));
-       const Real EPS = 1e-15;
        // Additional (not at this place, in R file)
        Real* gam2 = (Real*)malloc(k*sizeof(Real));
        Real* sqNorm2 = (Real*)malloc(k*sizeof(Real));
@@ -300,19 +300,26 @@ void EMGLLF_core(
                                        sqNorm2[r] += (YiRhoR[u]-XiPhiR[u]) * (YiRhoR[u]-XiPhiR[u]);
                        }
 
-                       Real sumGamI = 0.;
+                       // Update gam[,]; use log to avoid numerical problems
+                       Real maxLogGam = -INFINITY;
                        for (int r=0; r<k; r++)
                        {
-                               gam[mi(i,r,n,k)] = pi[r] * exp(-.5*sqNorm2[r]) * detRho[r];
-                               sumGamI += gam[mi(i,r,n,k)];
+                               logGam[r] = log(pi[r]) - .5 * sqNorm2[r] + log(detRho[r]);
+                               if (maxLogGam < logGam[r])
+                                       maxLogGam = logGam[r];
                        }
-
-                       sumLogLLH += log(sumGamI) - log(gaussConstM);
-                       if (sumGamI > EPS) //else: gam[i,] is already ~=0
+                       Real norm_fact = 0.;
+                       for (int r=0; r<k; r++)
                        {
-                               for (int r=0; r<k; r++)
-                                       gam[mi(i,r,n,k)] /= sumGamI;
+                               logGam[r] = logGam[r] - maxLogGam; //adjust without changing proportions
+                               gam[mi(i,r,n,k)] = exp(logGam[r]); //gam[i, ] <- exp(logGam)
+                               norm_fact += gam[mi(i,r,n,k)]; //norm_fact <- sum(gam[i, ])
                        }
+                       // gam[i, ] <- gam[i, ] / norm_fact
+                       for (int r=0; r<k; r++)
+                               gam[mi(i,r,n,k)] /= norm_fact;
+
+                       sumLogLLH += log(norm_fact) - log(gaussConstM);
                }
 
                //sumPen = sum(pi^gamma * b)
@@ -320,9 +327,9 @@ void EMGLLF_core(
                for (int r=0; r<k; r++)
                        sumPen += pow(pi[r],gamma) * b[r];
                Real last_llh = *llh;
-               //llh = -sumLogLLH/n + lambda*sumPen
-               *llh = -invN * sumLogLLH + lambda * sumPen;
-               Real dist = ite==1 ? *llh : (*llh - last_llh) / (1. + fabs(*llh));
+               //llh = -sumLogLLH/n #+ lambda*sumPen
+               *llh = -invN * sumLogLLH; //+ lambda * sumPen;
+               Real dist = ( ite==1 ? *llh : (*llh - last_llh) / (1. + fabs(*llh)) );
 
                //Dist1 = max( abs(phi-Phi) / (1+abs(phi)) )
                Real Dist1 = 0.;
@@ -372,7 +379,7 @@ void EMGLLF_core(
                if (Dist3 > dist2)
                        dist2 = Dist3;
 
-               if (ite >= mini && (dist >= tau || dist2 >= sqrt(tau)))
+               if (ite >= mini && (dist >= eps || dist2 >= sqrt(eps)))
                        break;
        }
 
@@ -394,6 +401,7 @@ void EMGLLF_core(
        //free memory
        free(b);
        free(gam);
+       free(logGam);
        free(Phi);
        free(Rho);
        free(Pi);