update the output to have the classification
[valse.git] / pkg / R / main.R
index 695a23f..89c4bcd 100644 (file)
@@ -131,9 +131,23 @@ print(tableauRecap)
 
   mod = as.character(tableauRecap[indModSel,1])
   listMod = as.integer(unlist(strsplit(mod, "[.]")))
-  if (plot){
-    print(plot_valse(models_list[[listMod[1]]][[listMod[2]]],n))
+  modelSel = models_list[[listMod[1]]][[listMod[2]]]
+  
+  ##Affectations
+  Gam = matrix(0, ncol = length(modelSel$pi), nrow = n)
+  for (i in 1:n){
+    for (r in 1:length(modelSel$pi)){
+      sqNorm2 = sum( (Y[i,]%*%modelSel$rho[,,r]-X[i,]%*%modelSel$phi[,,r])^2 )
+      Gam[i,r] = modelSel$pi[r] * exp(-0.5*sqNorm2)* det(modelSel$rho[,,r])
+    }
+  }
+  Gam = Gam/rowSums(Gam)
+  modelSel$affec = apply(Gam, 1,which.max)
+  modelSel$proba = Gam
+
+    if (plot){
+    print(plot_valse(modelSel,n))
   }
-  models_list[[listMod[1]]][[listMod[2]]]
   
+  return(modelSel)
 }