# The following function computes the multiple correlation coefficient of x on
# (y, z), corrected for attenuation while allowing the error scores to be
# correlated (non-independent).

# obs.mat, errs.mat, and test.mat are each expected to be n x 3 matrices with
# columns in the order (x, y, z), where x is the dependent variable and y and z
# are the independent variables.
# Inputs:
# obs.mat  = observed scores
# errs.mat = error scores
# Plus exactly one of the following:
# test.mat = retest scores; the test-retest correlations between obs.mat and
#            test.mat are used as the reliability values
# rels     = vector of pre-computed reliability values, in the order (x, y, z)

mult_corr_fun <- function(obs.mat, errs.mat, test.mat = NULL, rels = NULL) {
  # Returns a list with:
  #   r.obs      - observed multiple correlation of x on (y, z)
  #   r.spearman - multiple correlation after Spearman's correction
  #   r.new      - multiple correlation corrected for attenuation while
  #                allowing correlated errors
  #   rel.x, rel.y, rel.z                     - reliabilities used
  #   rho.xy, rho.xz, rho.yz                  - observed correlations
  #   rho.errs.xy, rho.errs.xz, rho.errs.yz   - error-score correlations
  #
  # Reliabilities come from test.mat (test-retest correlations) when it is
  # supplied, otherwise from rels. Supplying neither is an error; supplying
  # both uses test.mat and warns that rels is ignored.

  if (is.null(test.mat) && is.null(rels)) {
    stop("Supply either `test.mat` or `rels` to define reliabilities.",
         call. = FALSE)
  }
  if (!is.null(test.mat) && !is.null(rels)) {
    warning("Both `test.mat` and `rels` supplied; using test-retest ",
            "reliabilities from `test.mat` and ignoring `rels`.",
            call. = FALSE)
  }
  if (is.null(test.mat) && length(rels) < 3) {
    stop("`rels` must contain three reliabilities, in the order (x, y, z).",
         call. = FALSE)
  }

  # Pairwise correlation, dropping rows where either value is missing.
  col_cor <- function(a, b) cor(a, b, use = "complete.obs")

  # Multiple correlation of x on (y, z) from the three pairwise correlations.
  mult_r <- function(r.xy, r.xz, r.yz) {
    sqrt((r.xy^2 + r.xz^2 - 2 * r.xy * r.xz * r.yz) / (1 - r.yz^2))
  }

  obs.x <- obs.mat[, 1]
  obs.y <- obs.mat[, 2]
  obs.z <- obs.mat[, 3]

  errs.x <- errs.mat[, 1]
  errs.y <- errs.mat[, 2]
  errs.z <- errs.mat[, 3]

  # Observed correlations
  rho.xy <- col_cor(obs.x, obs.y)
  rho.xz <- col_cor(obs.x, obs.z)
  rho.yz <- col_cor(obs.y, obs.z)

  # Correlations of the error scores
  rho.errs.xy <- col_cor(errs.x, errs.y)
  rho.errs.xz <- col_cor(errs.x, errs.z)
  rho.errs.yz <- col_cor(errs.y, errs.z)

  if (is.null(test.mat)) {
    # Pre-computed reliabilities
    rel.x <- rels[1]
    rel.y <- rels[2]
    rel.z <- rels[3]
  } else {
    # Test-retest reliabilities
    rel.x <- col_cor(obs.x, test.mat[, 1])
    rel.y <- col_cor(obs.y, test.mat[, 2])
    rel.z <- col_cor(obs.z, test.mat[, 3])
  }

  # Observed multiple correlation
  r.obs <- mult_r(rho.xy, rho.xz, rho.yz)

  # Spearman's correction for attenuation: r_ij / sqrt(rel_i * rel_j)
  r.spearman <- mult_r(rho.xy / sqrt(rel.x * rel.y),
                       rho.xz / sqrt(rel.x * rel.z),
                       rho.yz / sqrt(rel.y * rel.z))

  # NEW multiple correlation coefficient: first remove the error component
  # rho.errs_ij * sqrt(1 - rel_i) * sqrt(1 - rel_j) from each observed
  # correlation. Then disattenuate. The num/denom form below equals mult_r()
  # applied to adj_ij / sqrt(rel_i * rel_j), with the reliabilities cleared
  # from the inner denominators.
  adj.xy <- rho.xy - rho.errs.xy * sqrt(1 - rel.x) * sqrt(1 - rel.y)
  adj.xz <- rho.xz - rho.errs.xz * sqrt(1 - rel.x) * sqrt(1 - rel.z)
  adj.yz <- rho.yz - rho.errs.yz * sqrt(1 - rel.y) * sqrt(1 - rel.z)

  num <- rel.z * adj.xy^2 + rel.y * adj.xz^2 - 2 * adj.xy * adj.xz * adj.yz
  denom <- rel.x * (rel.y * rel.z - adj.yz^2)
  r.new <- sqrt(num / denom)

  list(r.obs = r.obs, r.spearman = r.spearman, r.new = r.new,
       rel.x = rel.x, rel.y = rel.y, rel.z = rel.z,
       rho.xy = rho.xy, rho.xz = rho.xz, rho.yz = rho.yz,
       rho.errs.xy = rho.errs.xy, rho.errs.xz = rho.errs.xz,
       rho.errs.yz = rho.errs.yz)
}