From 9fadef2bff80d4b0371962dea4b6de24086f230b Mon Sep 17 00:00:00 2001
From: emilie <emilie@devijver.org>
Date: Wed, 12 Apr 2017 12:54:11 +0200
Subject: [PATCH] update the output to have the classification

---
 pkg/R/main.R       | 20 +++++++++++++++++---
 pkg/R/plot_valse.R | 13 ++-----------
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/pkg/R/main.R b/pkg/R/main.R
index 695a23f..89c4bcd 100644
--- a/pkg/R/main.R
+++ b/pkg/R/main.R
@@ -131,9 +131,23 @@ print(tableauRecap)
 
   mod = as.character(tableauRecap[indModSel,1])
   listMod = as.integer(unlist(strsplit(mod, "[.]")))
-  if (plot){
-    print(plot_valse(models_list[[listMod[1]]][[listMod[2]]],n))
+  modelSel = models_list[[listMod[1]]][[listMod[2]]]
+  
+  ##Affectations
+  Gam = matrix(0, ncol = length(modelSel$pi), nrow = n)
+  for (i in 1:n){
+    for (r in 1:length(modelSel$pi)){
+      sqNorm2 = sum( (Y[i,]%*%modelSel$rho[,,r]-X[i,]%*%modelSel$phi[,,r])^2 )
+      Gam[i,r] = modelSel$pi[r] * exp(-0.5*sqNorm2)* det(modelSel$rho[,,r])
+    }
+  }
+  Gam = Gam/rowSums(Gam)
+  modelSel$affec = apply(Gam, 1,which.max)
+  modelSel$proba = Gam
+
+    if (plot){
+    print(plot_valse(modelSel,n))
   }
-  models_list[[listMod[1]]][[listMod[2]]]
   
+  return(modelSel)
 }
diff --git a/pkg/R/plot_valse.R b/pkg/R/plot_valse.R
index 2c74554..120196d 100644
--- a/pkg/R/plot_valse.R
+++ b/pkg/R/plot_valse.R
@@ -48,20 +48,11 @@ plot_valse = function(model,n){
   print(gCov )
   
   ### proportions
-  Gam = matrix(0, ncol = K, nrow = n)
-  gam  = Gam
-  for (i in 1:n){
-    for (r in 1:K){
-      sqNorm2 = sum( (Y[i,]%*%model$rho[,,r]-X[i,]%*%model$phi[,,r])^2 )
-      Gam[i,r] = model$pi[r] * exp(-0.5*sqNorm2)* det(model$rho[,,r])
-    }
-    gam[i,] = Gam[i,] / sum(Gam[i,])
-  }
-  affec = apply(gam, 1,which.max)
   gam2 = matrix(NA, ncol = K, nrow = n)
   for (i in 1:n){
-    gam2[i, ] = c(gam[i, affec[i]], affec[i])
+    gam2[i, ] = c(model$Gam[i, model$affec[i]], model$affec[i])
   }
+  
   bp <- ggplot(data.frame(gam2), aes(x=X2, y=X1, color=X2, group = X2)) +
     geom_boxplot() + theme(legend.position = "none")+ background_grid(major = "xy", minor = "none")
   print(bp )
-- 
2.44.0