## ----include = FALSE----------------------------------------------------------
# Chunk options applied to the chunks of this document: collapse = TRUE and
# "#>" as the comment prefix.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## -----------------------------------------------------------------------------
library(SMMAL)

# Locate the example data files under "extdata" in the installed SMMAL package
# and read them with readRDS().
# mustWork = TRUE makes system.file() raise an error when a file cannot be
# found, instead of returning "" and letting readRDS() fail with an unrelated
# message.
file_path <- system.file(
  "extdata", "sample_data_withmissing.rds",
  package = "SMMAL", mustWork = TRUE
)
dat <- readRDS(file_path)

file_path2 <- system.file(
  "extdata", "semi_supervised_data.rds",
  package = "SMMAL", mustWork = TRUE
)
data_loaded <- readRDS(file_path2)

## -----------------------------------------------------------------------------
  # Y and A are numeric vectors, taken from dat unchanged.
  Y <- dat$Y
  A <- dat$A
  
  # S and X need to be data frames, so each is wrapped in data.frame().
  S <- data.frame(dat$S)
  X <- data.frame(dat$X)

## -----------------------------------------------------------------------------
 # Call SMMAL() with only Y, A, S and X supplied, leaving every other argument
 # at its default, then print the result.
 SMMAL_output1 <- SMMAL(
   Y = Y,
   A = A,
   S = S,
   X = X
 )
 print(SMMAL_output1)

## -----------------------------------------------------------------------------
# Call SMMAL() with cf_model set to "xgboost", then print the result.
SMMAL_output2 <- SMMAL(
  Y = Y,
  A = A,
  S = S,
  X = X,
  cf_model = "xgboost"
)
print(SMMAL_output2)

## -----------------------------------------------------------------------------
# Call SMMAL() with cf_model set to "randomforest", then print the result.
SMMAL_output3 <- SMMAL(
  Y = Y,
  A = A,
  S = S,
  X = X,
  cf_model = "randomforest"
)
print(SMMAL_output3)

## -----------------------------------------------------------------------------
 # Call SMMAL() with cf_model set to "glm", then print the result.
 SMMAL_output4 <- SMMAL(
   Y = Y,
   A = A,
   S = S,
   X = X,
   cf_model = "glm"
 )
 print(SMMAL_output4)

## -----------------------------------------------------------------------------
 # Call SMMAL() with the function SMMAL_ada_lasso passed as custom_model_fun,
 # then print the result.
 SMMAL_output5 <- SMMAL(
   Y = Y,
   A = A,
   S = S,
   X = X,
   custom_model_fun = SMMAL_ada_lasso
 )
 print(SMMAL_output5)

## -----------------------------------------------------------------------------
# Evaluating the bare function name displays the definition of SMMAL_ada_lasso
# wherever top-level results are auto-printed (e.g. in a rendered document).
SMMAL_ada_lasso

## -----------------------------------------------------------------------------
# Compactly display the structure of data_loaded, the object read from
# semi_supervised_data.rds.
str(data_loaded)

## -----------------------------------------------------------------------------
# Call SMMAL_ada_lasso() directly; each argument is taken from the element of
# data_loaded that has the same name as the argument.
SMMAL_fold_predictions <- SMMAL_ada_lasso(
  X               = data_loaded$X,
  Y               = data_loaded$Y,
  X_full          = data_loaded$X_full,
  foldid          = data_loaded$foldid,
  foldid_labelled = data_loaded$foldid_labelled,
  sub_set         = data_loaded$sub_set,
  labeled_indices = data_loaded$labeled_indices,
  nfold           = data_loaded$nfold,
  log_loss        = data_loaded$log_loss
)

# Compactly display the structure of the returned object.
str(SMMAL_fold_predictions)

