SAEM algorithm parameters

@Diptimangal @adunn34 Perhaps we can learn SAEM together, since otherwise each of us may be learning it on our own. Here is what I have done. If you followed my NIBR data discussion: I used that dataset to run ML (FOCE) and EM (SAEM) estimations with the code below. Could you check whether you can reproduce my results, correct any errors, and suggest improvements? I was unable to make some informative plots (FOCE vs. SAEM parameter estimates; FOCE vs. SAEM IPRED; FOCE vs. SAEM PRED…). Maybe you will be able to.

using Pumas
using PumasUtilities
using Dates
using CairoMakie
using AlgebraOfGraphics
using DataFramesMeta
using Random

#################
# TIME in hours
# CONC in ug/L 
# AMT_UG in ug 
# WEIGHTB in kg 
# DOSE in mg
#################

# Load the SAD PK dataset (column units are listed in the header comment above).
# NOTE(review): `CSV` is never imported explicitly in this script; this line relies on
# `CSV` being brought into scope by one of the `using` lines — add `using CSV` if not.
pkdata = CSV.read("./practice/nibr/SAD/data/sad_pk.csv", DataFrame)

# Build the Pumas population: subjects are grouped by :ID, :CONC is the only observed
# variable, dosing records come from :AMT_UG / :CMT / :EVID, and three baseline
# covariates are attached to each subject.
pop_nmle = read_pumas(
    pkdata;
    id = :ID,
    time = :TIME,
    observations = [:CONC],
    amt = :AMT_UG,
    evid = :EVID,
    cmt = :CMT,
    covariates = [:WEIGHTB, :SEX, :DOSE],
)


# One-compartment model with first-order absorption (Pumas `Depots1Central1`),
# estimated with Pumas.FOCEI() further down.
ka1cmt = @model begin
  
    # Population parameters. Ω uses PDiagDomain, i.e. a diagonal covariance with no
    # correlation between the three etas. σ²_add is the additive residual *variance*.
    @param begin
      tvka     ∈ RealDomain(lower=0.001)
      tvcl     ∈ RealDomain(lower=0.001)
      tvvc     ∈ RealDomain(lower=0.001)
      Ω        ∈ PDiagDomain(3)
      σ²_add   ∈ RealDomain(lower=0.001)
    end
  
    # Inter-individual random effects: η[1] -> Ka, η[2] -> CL, η[3] -> Vc (see @pre).
    @random begin
      η        ~ MvNormal(Ω)
    end

    # Declared, but none of these is used in @pre: no covariate effects are modelled yet.
    @covariates WEIGHTB DOSE SEX 
  
    # Individual parameters: log-normal variability around the typical values.
    @pre begin
      Ka       = tvka  *  exp(η[1])
      CL       = tvcl  *  exp(η[2])
      Vc       = tvvc  *  exp(η[3])
    end
  
    @dynamics Depots1Central1
  
    # cp = central amount / volume; with doses in ug and Vc in L this is ug/L, like CONC.
    # Additive error: Normal takes a standard deviation, hence sqrt of the variance.
    @derived begin
      cp         = @. (Central/Vc)
      CONC       ~ @. Normal(cp, sqrt(σ²_add))
    end
end

#### Units of the structural parameters
# tvka: 1/hr
# tvcl: L/hr
# tvvc: L
####

# Initial estimates for the FOCE model (units: tvka 1/hr, tvcl L/hr, tvvc L).
# Ω holds the diagonal variances of the three log-normal etas;
# σ²_add is the additive residual *variance*.
param = (
    tvka = 0.5,
    tvcl = 6.0e3,
    tvvc = 2.0e5,
    Ω = Diagonal(fill(0.05, 3)),
    σ²_add = 2,
)


# Start initial estimates explorer app
# Check the starting values in `param` interactively against the data before fitting.
ee_1cmp = explore_estimates(ka1cmt, pop_nmle, param)

# Maximum-likelihood fit of the one-compartment model.
# With a purely additive residual error the residual variance does not depend on η,
# so FOCE and FOCEI coincide; the printed fit indeed reports
# `Likelihood approximation: Pumas.FOCE`.
ka1cmt_fit = fit(
    ka1cmt,
    pop_nmle,
    param,
    # constantcoef = (tvka = 10,),   # uncomment to hold tvka fixed during the fit
    Pumas.FOCEI(),
)

#
# FOCE diagnostics:
#   - inspect/evaluate the fit,
#   - a long prediction table sorted by subject and time,
#   - parameter uncertainty via `infer`,
#   - per-subject individual parameters tagged with the estimation method.
ka1cmt_inspect = inspect(ka1cmt_fit)
ka1cmt_evaluate = evaluate_diagnostics(ka1cmt_inspect)
ka1cmt_inspect_df = sort!(DataFrame(ka1cmt_inspect), [:id, :time])

ka1cmt_infer = infer(ka1cmt_fit)
ka1cmt_icoef = reduce(vcat, map(DataFrame, icoef(ka1cmt_fit)))
@rtransform! ka1cmt_icoef :method = "foce"


################
## SAEM Model 
################

# SAEM version of the one-compartment model, written with Pumas' @emmodel DSL.
ka1cmt_saem = @emmodel begin
  # Typical values. `1 | LogNormal` presumably means "intercept only, log-normally
  # distributed" — confirm against the Pumas @emmodel documentation.
  @param begin
    tvka ~ 1 | LogNormal # RealDomain(lower=0.01)
    tvcl ~ 1 | LogNormal # RealDomain(lower=0.2)
    tvvc ~ 1 | LogNormal # RealDomain(lower=0.1)
  end
  # NOTE(review): the printed SAEM fit lists η1–η3 as point estimates next to
  # tvka/tvcl/tvvc and reports only three Ω terms. So these etas appear to act as
  # additional fixed effects rather than random effects.
  # The effective typical values are then
  #   tvka*exp(η1) ≈ 1.248, tvcl*exp(η2) ≈ 9972, tvvc*exp(η3) ≈ 51749,
  # which are close to the FOCE estimates (1.320, 10227, 51450); the raw tv* are not.
  # tvX and the η means are therefore confounded. Check whether the LogNormal in
  # @param already carries the IIV (with Ω from @covariance); if so, drop this block.
  @random begin
    η1 ~ 1 | Normal
    η2 ~ 1 | Normal
    η3 ~ 1 | Normal
  end
  # Diagonal Ω made of three 1x1 blocks; the fit reports Ω₁,₁, Ω₂,₂, Ω₃,₃.
  @covariance (1, 1, 1)
  @pre begin
    Ka = tvka * exp(η1)
    CL = tvcl * exp(η2)
    Vc = tvvc * exp(η3)
  end
  @dynamics Depots1Central1
  # The prediction cp is computed here and referenced by the @error block below.
  @post begin
    cp = @. Central / Vc
  end

  # @emmodel rejects a @derived block (LoadError: "Macro block @derived not supported."),
  # and the FOCE-style residual `CONC ~ @. Normal(cp, sqrt(σ²_add))` is not accepted
  # either, so the residual model is written in @error instead:
  #@derived begin
  #  cp         = @. (Central/Vc)
  #  CONC       ~ @. Normal(cp, sqrt(σ²_add))
  #end

  # Additive residual error. The fit reports its scale as `σ` (2.0323). If that is a
  # standard deviation (as for Normal), compare it with sqrt(3.3067) ≈ 1.818 from FOCE,
  # not with the FOCE variance σ²_add itself.
  @error begin
    CONC ~ Normal(cp)
  end
end

# Initial estimates for the SAEM model; the keys mirror the @emmodel declarations.
# NOTE(review): the SAEM fit reports the residual parameter as `σ`, not `σ²_add`.
# Confirm whether the `σ²_add` entry below is actually used as a starting value.
param_saem = (
    tvka = 0.25,
    tvcl = 6.0e3,
    tvvc = 2.0e5,
    # Starting values for the model's @random entries are required too.
    η1 = 0.0, η2 = 0.0, η3 = 0.0,
    Ω = ntuple(_ -> 0.05, 3),
    σ²_add = 2,
)

## Fit base model with SAEM
# SAEM samples the random effects stochastically, so the estimates depend on the
# RNG state. Seeding one MersenneTwister per thread is meant to let others reproduce
# these numbers; results may still depend on Threads.nthreads().
# Fit only once: an additional unseeded fit would cost a full SAEM run, would not be
# reproducible, and would simply be overwritten by this one.
rngv = [MersenneTwister(1941964947 * i + 1) for i in 1:Threads.nthreads()]
ka1cmt_em_fit = fit(
    ka1cmt_saem,
    pop_nmle,
    param_saem,
    Pumas.SAEM();
    ensemblealg = EnsembleThreads(),
    rng = rngv,
)

# SAEM diagnostics, mirroring the FOCE section:
#   - inspect/evaluate the fit,
#   - a prediction table sorted by subject and time,
#   - `infer` at level 0.95,
#   - per-subject individual parameters tagged with the estimation method.
ka1cmt_em_inspect = inspect(ka1cmt_em_fit)
ka1cmt_em_evaluate = evaluate_diagnostics(ka1cmt_em_inspect)
ka1cmt_em_inspect_df = sort!(DataFrame(ka1cmt_em_inspect), [:id, :time])

ka1cmt_em_infer = infer(ka1cmt_em_fit; level = 0.95)
ka1cmt_em_icoef = reduce(vcat, map(DataFrame, icoef(ka1cmt_em_fit)))
@rtransform! ka1cmt_em_icoef :method = "saem"

###########
# Keep :id plus the FOCE individual parameters, suffixed with the method name, so the
# FOCE and SAEM tables can later be joined side by side on :id.
ka1cmt_icoef = select(
    ka1cmt_icoef,
    :id,
    :CL => :CLfoce,
    :Vc => :Vcfoce,
    :Ka => :Kafoce,
)

# Keep :id plus the SAEM individual parameters, suffixed with the method name, so the
# FOCE and SAEM tables can later be joined side by side on :id.
ka1cmt_em_icoef = select(
    ka1cmt_em_icoef,
    :id,
    :CL => :CLsaem,
    :Vc => :Vcsaem,
    :Ka => :Kasaem,
)

# Per-subject percent difference of the individual parameters,
# 100 * (SAEM - FOCE) / FOCE, summarized across subjects.
#
# `unique` guards against several identical rows per subject in the icoef tables
# (e.g. one row per time point). Joining such tables on :id alone would pair every
# row of a subject with every other row of that subject, over-weighting subjects with
# many records in mean/std. It is a no-op when each subject already has one row.
icoef_foce_saem = innerjoin(
    unique(ka1cmt_icoef),
    unique(ka1cmt_em_icoef);
    on = :id,
    makeunique = true,
)

icoef_foce_saem = @select(
    icoef_foce_saem,
    :id,
    :dCL = (:CLsaem - :CLfoce) * 100 ./ :CLfoce,
    :dVc = (:Vcsaem - :Vcfoce) * 100 ./ :Vcfoce,
    :dKa = (:Kasaem - :Kafoce) * 100 ./ :Kafoce,
)

summarize(icoef_foce_saem;
          parameters = [:dCL, :dVc, :dKa],
          stats = [NCA.extrema, NCA.mean, NCA.std])

#
# Percent difference of SAEM vs FOCE predictions at each observation,
# 100 * (SAEM - FOCE) / FOCE. This is the same sign convention as dCL/dVc/dKa;
# the columns suffixed `_1` come from the right-hand (SAEM) table.
#
# Likely cause of the earlier `summarize` error: the inspect tables also contain
# dose/event rows, presumably with `missing` CONC/CONC_pred/CONC_ipred. In addition,
# a zero FOCE prediction (e.g. a pre-dose sample) makes the ratio NaN/Inf.
# Fixes applied below:
#   - Drop rows without an observation or prediction before joining. This also avoids
#     a many-to-many join when a dose and an observation share the same (id, time).
#   - Drop rows whose FOCE reference prediction is zero.
inspect_foce_saem = innerjoin(
    dropmissing(ka1cmt_inspect_df, [:CONC, :CONC_ipred, :CONC_pred]),
    dropmissing(ka1cmt_em_inspect_df, [:CONC, :CONC_ipred, :CONC_pred]);
    on = [:id, :time],
    makeunique = true,
)

inspect_foce_saem = @chain inspect_foce_saem begin
    @rsubset(:CONC_ipred != 0 && :CONC_pred != 0)
    @transform(:dCONC_ipred = (:CONC_ipred_1 - :CONC_ipred) * 100 ./ :CONC_ipred)
    @transform(:dCONC_pred = (:CONC_pred_1 - :CONC_pred) * 100 ./ :CONC_pred)
    # :time is kept so the differences can also be plotted over time.
    @select(:id, :time, :dCONC_ipred, :dCONC_pred)
end

summarize(inspect_foce_saem;
          parameters = [:dCONC_ipred, :dCONC_pred],
          stats = [NCA.extrema, NCA.mean, NCA.std])

julia> ka1cmt_fit # FOCE FIT
FittedPumasModel

Successful minimization: true

Likelihood approximation: Pumas.FOCE
Log-likelihood value: -1291.1646
Number of subjects: 50
Number of parameters: Fixed Optimized
0 7
Observation records: Active Missing
CONC: 600 0
Total: 600 0


       Estimate

tvka 1.3204
tvcl 10227.0
tvvc 51450.0
Ω₁,₁ 0.21713
Ω₂,₂ 0.10226
Ω₃,₃ 0.22723
σ²_add 3.3067

julia> ka1cmt_em_fit # SAEM FIT
FittedPumasEMModel


     Estimate

tvka 1.2853
tvcl 8980.0
tvvc 47547.0
η1 -0.029223
η2 0.10481
η3 0.084683
Ω₁,₁ 0.023733
Ω₂,₂ 0.021677
Ω₃,₃ 0.1683
σ 2.0323