#!/usr/bin/env bash
set -eo pipefail

# Record how this script was invoked so that the command line can later be
# written out verbatim. Every argument is shell-quoted with %q; "$@" must be
# quoted here, otherwise an argument containing whitespace is split into
# several words before it is quoted and the recorded command is wrong.
src=${BASH_SOURCE[0]}
printf -v args ' %q' "$@"
if (( $# > 0 )); then
    INVOCATION="${src}${args}"
else
    # With no arguments printf still applies the format once and yields an
    # empty quoted field, so only the script path is recorded.
    INVOCATION="${src}"
fi

# Ask the installed medaka for its version banner and for its model listing.
medaka_version="$(medaka --version)"

# One array element per line printed by `medaka tools list_models`
# (surrounding whitespace is trimmed by `read`). Positions used below:
#   [0] Available models
#   [1] default consensus model
#   [2] default SNP model
#   [3] default variant model
declare -a modeldata=()
while read -r line; do
    modeldata[${#modeldata[@]}]="${line}"
done < <(medaka tools list_models)

# Print the real path of a file as computed by Python's os.path.realpath.
# Arguments: $1 - path to resolve
# Outputs:   the resolved path on stdout
# The argument is quoted so that a path containing whitespace reaches Python
# as a single sys.argv entry.
function follow_link {
  python -c 'import os,sys; print(os.path.realpath(sys.argv[1]))' "$1"
}

# Fixed options appended to every whatshap phase/haplotag call.
readonly WHATSHAPOPTS="--ignore-read-groups"

# Defaults for everything that option parsing may override.

# Output location and resources.
OUTPUT='medaka_variant'
THREADS=1
BATCH_SIZE=100
REGIONS=''
SAMPLE='SAMPLE'

# Models: the last whitespace separated word of the corresponding
# list_models line (index 2: SNP model, index 3: variant model).
SNPMODEL="${modeldata[2]##* }"
MODEL="${modeldata[3]##* }"

# Quality thresholds applied when filtering the final VCF.
SNP_THRESHOLD=8
INDEL_THRESHOLD=9

# Boolean switches; they hold the command names `true`/`false` and are
# executed directly in conditionals.
iflag=false
fflag=false
PHASEOUT=false
DELETE=false
FILTER_FLAG=true
STOPAFTERROUND0=false

# Extra options accumulated for `medaka tools haploid2diploid`.
HAPLOID2DIPLOID_OPTS=''

# Help text shown for -h and whenever a required option is missing. The
# defaults are interpolated from the variables set above so the text cannot
# drift from the values actually used; the synopsis lists both mandatory
# options (-i and -f).
usage="
${medaka_version}
-------------

Variant calling via neural networks. The input bam should be aligned to the
reference against which to call variants. Note: that although configurable
it is unlikely that the model arguments should be changed from their defaults.

$(basename "$0") [-h] -i <bam> -f <fasta>

    -h  show this help text.
    -i  input bam of reads aligned to ref. Read groups are currently ignored,
        so the bam should only contain reads from a single sample.
    -f  input fasta reference (required).
    -r  region string(s). If providing multiple regions, wrap them in quotes.
        If not provided, will process all contigs in bam. 
    -o  output folder (default: ${OUTPUT}).
    -s  medaka model for initial SNP calling from mixed reads prior to phasing,
        (default: ${SNPMODEL}).
        ${modeldata[0]}.
        Alternatively a .hdf file from 'medaka train'. 
    -m  medaka model for final variant calling from phased reads,
        (default: ${MODEL}).
        ${modeldata[0]}.
        Alternatively a .hdf file from 'medaka train'. 
    -t  number of threads with which to create features (default: ${THREADS}).
    -p  output phased vcf.
    -b  batchsize, controls memory use (default: ${BATCH_SIZE}).
    -d  delete intermediate files. (default: keep).
    -N  threshold for filtering indels in final VCF (default: ${INDEL_THRESHOLD})
    -P  threshold for filtering SNPs in final VCF (default: ${SNP_THRESHOLD})
    -U  Avoid filtering of final VCF (default: do filter)
    -S  stop after initial SNP calling from mixed reads prior to phasing.
    -l  split MNPs into SNPs.
    -n  set this sample name in the output VCF and phased bam."


# Parse the command line. The leading ':' in the option string selects
# getopts' silent mode, so unknown options and missing arguments are reported
# by the '\?' and ':' arms below. Every $OPTARG that is passed on to another
# command is quoted so that paths and model names containing whitespace stay
# intact.
# NOTE(review): the option string is left unchanged ('i::' and 'T:'); the
# doubled colon after 'i' looks unintended — confirm before simplifying.
while getopts ':hi::f:r:o:m:s:T:t:pb:N:P:UdSn:l' option; do
  case "$option" in
    h  ) echo "$usage" >&2; exit;;
    i  ) iflag=true; CALLS2REF=$(follow_link "$OPTARG");;
    f  ) fflag=true; REF=$(follow_link "$OPTARG");;
    r  ) REGIONS="$OPTARG";;
    o  ) OUTPUT=$OPTARG;;
    m  ) MODEL=$(medaka tools resolve_model --model "$OPTARG");;
    s  ) SNPMODEL=$(medaka tools resolve_model --model "$OPTARG");;
    # -T is accepted with an argument but its value is not used anywhere in
    # this script. NOTE(review): presumably kept so older command lines keep
    # working — confirm.
    T  ) ;;
    t  ) THREADS=$OPTARG;;
    p  ) PHASEOUT=true;;
    b  ) BATCH_SIZE=$OPTARG;;
    d  ) DELETE=true;;
    N  ) INDEL_THRESHOLD=$OPTARG;;
    P  ) SNP_THRESHOLD=$OPTARG;;
    U  ) FILTER_FLAG=false;;
    S  ) STOPAFTERROUND0=true;;
    n  ) SAMPLE=$OPTARG;;
    l  ) HAPLOID2DIPLOID_OPTS="${HAPLOID2DIPLOID_OPTS} --split_mnp";;
    \? ) echo "Invalid option: -${OPTARG}." >&2; exit 1;;
    :  ) echo "Option -$OPTARG requires an argument." >&2; exit 1;;
  esac
done
# Drop the parsed options, leaving only positional arguments.
shift $((OPTIND - 1))


# -i and -f are both mandatory. Work out which one (checked in that order)
# was not seen during option parsing; if either is missing, print the help
# text, a blank line and the complaint to stderr and stop with status 1.
missing_opt_msg=""
$iflag || missing_opt_msg="-i must be specified."
if [[ -z "${missing_opt_msg}" ]]; then
  $fflag || missing_opt_msg="-f must be specified."
fi

if [[ -n "${missing_opt_msg}" ]]; then
  {
    echo "$usage"
    echo ""
    echo "${missing_opt_msg}"
  } >&2
  exit 1
fi

# Extract the requested region(s) of a bam into a new bam and index it.
# Globals:   THREADS (read)
# Arguments: $1   - input bam; it and its .bai index must exist
#            $2   - output bam; extraction is skipped when it already exists
#            $3.. - region string(s), either as one quoted string or already
#                   split into several words
# Exits the script with status 1 when samtools fails.
run_extract_region_from_bam () {
    echo "+ Extracting region from bam"
    local BAMIN=$1
    local BAMOUT=$2
    # Gather everything after the second argument: a caller that passes the
    # regions unquoted would otherwise have all but the first one dropped.
    local REGIONS="${*:3}"

    exit_if_file_does_not_exist "${BAMIN}" "${BAMIN}.bai"
    if [[ ! -e ${BAMOUT} ]]; then
        echo "- Extracting regions $REGIONS from bam ${BAMIN} "
        # REGIONS is deliberately unquoted: every whitespace separated region
        # has to reach samtools as its own argument.
        # A brace group (not a subshell) is used so that `exit` really ends
        # the script instead of only the subshell.
        # shellcheck disable=SC2086
        samtools view -b -h "${BAMIN}" -@ "${THREADS}" -o "${BAMOUT}" ${REGIONS} \
            || { echo "* Failed to extract regions from bam."; exit 1; }
    else 
        echo "- Not extracting regions from bam, ${BAMOUT} exists."
    fi
    run_samtools_index "${BAMOUT}"
}

# Run `medaka consensus` to produce a consensus probabilities file.
# Globals:   REGIONS, BATCH_SIZE, THREADS (read)
# Arguments: $1 - input bam; it and its .bai index must exist
#            $2 - output probabilities file; skipped when it already exists
#            $3 - medaka model
#            $4 - optional string of extra options, split on whitespace
# Exits the script with status 1 when medaka fails.
run_medaka_consensus () {
    echo "+ Running consensus calling"
    local BAM=$1
    local PROBS=$2
    local MEDAKAMODEL=$3
    local EXTRAOPTS=$4
    local REG_OPT=""
    if [[ -n "${REGIONS}" ]]; then
        REG_OPT="--regions ${REGIONS}"
    fi
    exit_if_file_does_not_exist "${BAM}" "${BAM}.bai"
    if [[ ! -e ${PROBS} ]]; then
        echo "- Running medaka consensus with ${BAM} --model ${MEDAKAMODEL} ${EXTRAOPTS}"
        # REG_OPT and EXTRAOPTS hold several words each and are expanded
        # unquoted on purpose; file names and scalar values are quoted.
        # The brace group makes `exit` end the script itself rather than a
        # subshell.
        # shellcheck disable=SC2086
        medaka consensus "${BAM}" "${PROBS}" --model "${MEDAKAMODEL}" \
            --batch_size "${BATCH_SIZE}" ${REG_OPT} --threads "${THREADS}" ${EXTRAOPTS} \
            || { echo "* Failed to run medaka consensus."; exit 1; }
    else 
        echo "- Not running medaka consensus, ${PROBS} exists."
    fi
}

# Call SNPs from consensus probabilities with `medaka snp`.
# Globals:   REF, REGIONS, OUTPUT (read)
# Arguments: $1 - consensus probabilities file (must exist)
#            $2 - output VCF; calling is skipped when it already exists
#            $4 - optional string of extra options, split on whitespace
#                 NOTE(review): $3 is never read, so extra options have to
#                 be given as the fourth argument — confirm this is intended.
# Exits the script with status 1 when medaka fails.
run_medaka_snp () {
    echo "+ Running snp calling"
    local PROBS=$1
    local VCF=$2
    local EXTRAOPTS=$4
    local REG_OPT=""
    if [[ -n "${REGIONS}" ]]; then
        REG_OPT="--regions ${REGIONS}"
    fi
    exit_if_file_does_not_exist "${PROBS}" "${REF}"
    if [[ ! -e ${VCF} ]]; then
        echo "- Running medaka snp"
        # REG_OPT and EXTRAOPTS are multi-word option strings and are
        # expanded unquoted on purpose. The brace group makes `exit` end the
        # script itself rather than a subshell.
        # shellcheck disable=SC2086
        medaka snp "${REF}" "${PROBS}" "${VCF}" ${REG_OPT} ${EXTRAOPTS} \
            || { echo "* Failed to call snps from consensus probabilities."; exit 1; }
        echo "  VCF written to ${OUTPUT}/${VCF}."
    else
        echo "- Using existing output:  ${OUTPUT}/${VCF}."
    fi
}

# Call variants from consensus probabilities with `medaka variant`.
# Globals:   REF, REGIONS, OUTPUT (read)
# Arguments: $1 - consensus probabilities file (must exist)
#            $2 - output VCF; calling is skipped when it already exists
#            $4 - optional string of extra options, split on whitespace
#                 NOTE(review): $3 is never read, so extra options have to
#                 be given as the fourth argument — confirm this is intended.
# Exits the script with status 1 when medaka fails.
run_medaka_variant () {
    echo "+ Running variant calling"
    local PROBS=$1
    local VCF=$2
    local EXTRAOPTS=$4
    local REG_OPT=""
    if [[ -n "${REGIONS}" ]]; then
        REG_OPT="--regions ${REGIONS}"
    fi
    exit_if_file_does_not_exist "${PROBS}" "${REF}"
    if [[ ! -e ${VCF} ]]; then
        echo "- Running medaka variant"
        # REG_OPT and EXTRAOPTS are multi-word option strings and are
        # expanded unquoted on purpose. The brace group makes `exit` end the
        # script itself rather than a subshell.
        # shellcheck disable=SC2086
        medaka variant "${REF}" "${PROBS}" "${VCF}" ${REG_OPT} ${EXTRAOPTS} \
            || { echo "* Failed to call variants from consensus probabilities."; exit 1; }
        echo "  VCF written to ${OUTPUT}/${VCF}."
    else
        echo "- Using existing output:  ${OUTPUT}/${VCF}."
    fi
}

# Merge two haploid VCFs into one diploid VCF with
# `medaka tools haploid2diploid`.
# Globals:   OUTPUT (read)
# Arguments: $1 - first haploid VCF
#            $2 - second haploid VCF
#            $3 - reference fasta (shadows the global REF inside this function)
#            $4 - merged output VCF; merging is skipped when it already exists
#            $5 - optional string of extra options, split on whitespace
# Exits the script with status 1 when medaka fails.
run_medaka_haploid2diploid () {
    echo "+ Merging vcfs to diploid vcf"
    # by default returns a phased vcf
    local VCF1=$1
    local VCF2=$2
    local REF=$3
    local MERGED=$4
    local OPTS=$5

    exit_if_file_does_not_exist "${VCF1}" "${VCF2}" "${REF}"
    if [[ ! -e ${MERGED} ]]; then
        echo "- Running medaka haploid2diploid "
        # OPTS is a multi-word option string and is expanded unquoted on
        # purpose. The brace group makes `exit` end the script itself rather
        # than a subshell.
        # shellcheck disable=SC2086
        medaka tools haploid2diploid "${VCF1}" "${VCF2}" "${REF}" "${MERGED}" ${OPTS} \
            || { echo "* Failed to run medaka haploid2diploid."; exit 1; }
        echo "  VCF written to ${OUTPUT}/${MERGED}."
    else
        echo "- Using existing output:  ${OUTPUT}/${MERGED}."
    fi
}

# Phase a VCF with `whatshap phase`.
# Globals:   REF, CALLS2REF, WHATSHAPOPTS, OUTPUT (read)
# Arguments: $1 - input VCF (must exist)
#            $2 - output phased VCF; phasing is skipped when it already exists
# When the input VCF has no records it is copied to the output unchanged.
# Exits the script with status 1 when whatshap fails.
run_phase_vcf () {
    echo "+ Phasing vcf"
    local VCFIN=$1 
    local VCFOUT=$2
    exit_if_file_does_not_exist "${VCFIN}"
    run_samtools_faidx "${REF}"
    # Look for at least one line without '#' in the VCF handed to THIS call;
    # testing a global file name instead would inspect the wrong file when
    # the function is used for a different VCF.
    if ! grep -q -v "#" "${VCFIN}"; then
        echo "- Input vcf has no records, writing empty vcf"
        cp "${VCFIN}" "${VCFOUT}"
    elif [[ ! -e ${VCFOUT} ]]; then
        echo "- Running whatshap phase for: ${VCFIN} > ${VCFOUT}."
        # WHATSHAPOPTS is an option string and is expanded unquoted on
        # purpose. The brace group makes `exit` end the script itself rather
        # than a subshell.
        # shellcheck disable=SC2086
        whatshap phase --reference "${REF}" -o "${VCFOUT}" "${VCFIN}" "${CALLS2REF}" ${WHATSHAPOPTS} \
            || { echo "* Failed to phase variants in ${VCFIN}."; exit 1; }
        echo "  Phased VCF written to ${OUTPUT}/${VCFOUT}."
    else
        echo "- Using existing output:  ${OUTPUT}/${VCFOUT}."
    fi
}

# Haplotag the reads with `whatshap haplotag`, set the read group with
# `samtools addreplacerg`, and index the resulting bam.
# Globals:   REF, CALLS2REF, WHATSHAPOPTS, SAMPLE, OUTPUT (read)
# Arguments: $1 - phased input VCF
#            $2 - output bam; tagging is skipped when it already exists
# Exits the script with status 1 when the pipeline fails. Failure of the
# first pipeline stage is only detected when the caller has `pipefail` set.
run_whatshap_tag () {
    echo "+ Tagging reads"
    local VCFIN=$1
    local BAMOUT=$2
    if [[ ! -e ${BAMOUT} ]]; then
        echo "- Running whatshap tag for: ${VCFIN} > ${BAMOUT}"
        # Leaving out -o completely did not work for me
        # WHATSHAPOPTS is an option string and is expanded unquoted on
        # purpose. On failure the partially written bam is removed: the
        # redirection creates it even when the commands fail, and a left-over
        # file would be taken for a finished result on the next run.
        # shellcheck disable=SC2086
        whatshap haplotag -o /dev/stdout --reference "${REF}" "${VCFIN}" "${CALLS2REF}" ${WHATSHAPOPTS} \
            | samtools addreplacerg -O BAM -r "@RG\tID:${SAMPLE}\tSM:${SAMPLE}" - > "${BAMOUT}" \
            || { rm -f "${BAMOUT}"; echo "Failed to partition reads in round 0."; exit 1; }
        echo "  Tagged reads written to ${OUTPUT}/${BAMOUT}."
    else
        echo "- Using existing output:  ${OUTPUT}/${BAMOUT}."
    fi
    run_samtools_index "${BAMOUT}"
}

# Compress a VCF with bgzip and index the result with tabix.
# Globals:   OUTPUT (read)
# Arguments: $1 - input file (must exist)
#            $2 - compressed output; compression is skipped when it exists
# Exits the script with status 1 when bgzip fails.
run_bgzip () {
    echo "+ Compressing vcf"
    local FILEIN=$1
    local FILEOUT=$2
    exit_if_file_does_not_exist "${FILEIN}"
    if [[ ! -e ${FILEOUT} ]]; then
        echo "- Compressing ${FILEIN}."
        # On failure the partially written output is removed: the redirection
        # creates it even when bgzip fails, and a left-over file would be
        # taken for a finished result on the next run. The brace group makes
        # `exit` end the script itself rather than a subshell.
        bgzip -c "${FILEIN}" > "${FILEOUT}" \
            || { rm -f "${FILEOUT}"; echo "* Failed to compress ${FILEIN}."; exit 1; }
        echo "  Compressed file written to ${OUTPUT}/${FILEOUT}."
    else
        echo "- Using existing output:  ${OUTPUT}/${FILEOUT}."
    fi
    run_tabix "${FILEOUT}"
}

# Index a bgzip-compressed VCF with tabix.
# Globals:   OUTPUT (read)
# Arguments: $1 - compressed VCF (must exist); the index is <$1>.tbi and
#                 indexing is skipped when that file already exists
# Exits the script with status 1 when tabix fails.
run_tabix () {
    echo "+ Indexing vcf"
    local FILEIN=$1
    local FILEOUT=${1}.tbi
    exit_if_file_does_not_exist "${FILEIN}"
    if [[ ! -e ${FILEOUT} ]]; then
        echo "- Indexing ${FILEIN}."
        # The message names the step that actually failed (indexing). The
        # brace group makes `exit` end the script itself rather than a
        # subshell.
        tabix -p vcf "${FILEIN}" \
            || { echo "* Failed to index ${FILEIN}."; exit 1; }
        echo "  Indexed file written to ${OUTPUT}/${FILEOUT}."
    else
        echo "- Using existing output:  ${OUTPUT}/${FILEOUT}."
    fi
}

# Create a .fai index for a fasta file with `samtools faidx`.
# Globals:   OUTPUT (read)
# Arguments: $1 - fasta file (must exist); the index is <$1>.fai and
#                 indexing is skipped when that file already exists
# Exits the script with status 1 when samtools fails.
run_samtools_faidx () {
    echo "+ Indexing fasta"
    local FILEIN=$1
    local FILEOUT=${1}.fai
    exit_if_file_does_not_exist "${FILEIN}"
    if [[ ! -e ${FILEOUT} ]]; then
        echo " - Indexing ${FILEIN}."
        # The brace group makes `exit` end the script itself rather than a
        # subshell.
        samtools faidx "${FILEIN}" \
            || { echo " * Failed to index ${FILEIN}."; exit 1; }
        echo "  Created index ${OUTPUT}/${FILEOUT}."
    else
        echo "- Using existing output:  ${OUTPUT}/${FILEOUT}."
    fi
}

# Create a .bai index for a bam with `samtools index`.
# Globals:   THREADS, OUTPUT (read)
# Arguments: $1 - bam file (must exist); the index is <$1>.bai and indexing
#                 is skipped when that file already exists
# Exits the script with status 1 when samtools fails.
run_samtools_index () {
    echo "+ Indexing bam"
    local BAM=$1
    local INDEX=${BAM}.bai
    exit_if_file_does_not_exist "${BAM}"

    if [[ ! -e ${INDEX} ]]; then
        echo "- Running samtools index on ${BAM}."
        # The brace group makes `exit` end the script itself rather than a
        # subshell.
        samtools index -@ "${THREADS}" "${BAM}" \
            || { echo "* Failed to index bam ${BAM}."; exit 1; }
        echo "  Bam index in ${OUTPUT}/${INDEX}."
    else
        echo "- Using existing output:  ${OUTPUT}/${INDEX}."
    fi
}

# Exit the script with status 1 as soon as one of the given paths does not
# exist.
# Arguments: $@ - paths to check, in order
# Outputs:   a message naming the first missing path on stdout
exit_if_file_does_not_exist () {
    # Keep the loop variable local so that calling this helper does not
    # overwrite a global of the same name.
    local FILE
    for FILE in "$@"; do
        if [[ ! -e ${FILE} ]]; then
            echo "* Could not find file ${FILE}, exiting."
            exit 1
        fi
    done
}


###
# Set-up
echo ""
echo "++ Preparing data"
echo "- Checking program versions"
medaka_version_report || exit 1

# Create the output folder, or warn that results already present in it may
# be reused by the steps below (each step skips work whose output exists).
if [[ ! -e ${OUTPUT} ]]; then
  mkdir -p "${OUTPUT}"
else
  echo "- Warning: Output ${OUTPUT} already exists, may use old results."
fi

# All intermediate and final files are written relative to the output folder.
cd "${OUTPUT}" || exit 1

# check input bam, bam index and ref exist
exit_if_file_does_not_exist "${CALLS2REF}" "${CALLS2REF}.bai" "${REF}"

# if regions have been provided, extract that region from the input bam
# (whatshap currently does not have a region string option). 

if [[ -n "${REGIONS}" ]]; then
    # The extracted bam is written into the output folder under the input's
    # base name and replaces the input for all following steps.
    CALLS2REF_EXTRACTED=$(basename "${CALLS2REF}")
    # REGIONS is passed as ONE quoted argument so that every region reaches
    # the helper; unquoted, only the first word would land in its third
    # parameter.
    run_extract_region_from_bam "${CALLS2REF}" "${CALLS2REF_EXTRACTED}" "${REGIONS}"
    CALLS2REF=${CALLS2REF_EXTRACTED}
fi


###
# First pass to haplotag reads
echo ""
echo "++ Running first pass to haplotag reads."
ROUND=0

# The sample name is inserted through the replacement side of sed
# `s#...#...#` commands below. Escape backslash, ampersand and the '#'
# delimiter so that any sample name is inserted literally.
SAMPLE_SED=$(printf '%s' "${SAMPLE}" | sed -e 's/[\\&#]/\\&/g')

CONSENSUSPROBS=round_${ROUND}_hap_mixed_probs.hdf
UNPHASEDVCF=round_${ROUND}_hap_mixed_unphased.vcf
run_medaka_consensus "${CALLS2REF}" "${CONSENSUSPROBS}" "${SNPMODEL}"
run_medaka_snp "${CONSENSUSPROBS}" "${UNPHASEDVCF}"

# Set correct sample name
sed -i "s#\tSAMPLE#\t${SAMPLE_SED}#" "${UNPHASEDVCF}"

# check empty vcf, write empty final output
UNFILTERED=round_1_unfiltered.vcf
FINAL=round_1.vcf

# A VCF without any line free of '#' holds only header lines.
if ! grep -q -v "#" "${UNPHASEDVCF}"; then
    echo "- First pass vcf has no records, writing empty vcf and finishing."
    cp "${UNPHASEDVCF}" "${FINAL}"  # this gets us the header
elif ${STOPAFTERROUND0}; then
    echo "- Stopping after first pass, copying ${UNPHASEDVCF} to ${FINAL}."
    echo "- Final VCF written to ${OUTPUT}/${FINAL}"
    cp "${UNPHASEDVCF}" "${FINAL}"
    exit 0
else
    PHASEDVCF=round_${ROUND}_hap_mixed_phased.vcf
    PHASEDVCFGZ=${PHASEDVCF}.gz
    PHASEDBAM=round_${ROUND}_hap_mixed_phased.bam
    run_phase_vcf "${UNPHASEDVCF}" "${PHASEDVCF}"
    run_bgzip "${PHASEDVCF}" "${PHASEDVCFGZ}"
    run_whatshap_tag "${PHASEDVCFGZ}" "${PHASEDBAM}"
    
    ###
    # Second pass to call haplotypes and merge to form phased diploid vcf
    echo ""
    echo "++ Running second pass to produce variant calls."
    ROUND=1
    # Reads tagged HP=1 / HP=2 are processed separately; --tag_keep_missing
    # is passed on to medaka consensus in both runs.
    for HAP in 1 2; do
        PROBS=round_${ROUND}_hap_${HAP}_probs.hdf
        VCF=round_${ROUND}_hap_${HAP}.vcf
        run_medaka_consensus "${PHASEDBAM}" "${PROBS}" "${MODEL}" \
            "--tag_name HP --tag_value ${HAP} --tag_keep_missing"
        run_medaka_variant "${PROBS}" "${VCF}"
    done;
    
    run_medaka_haploid2diploid \
        "round_${ROUND}_hap_1.vcf" "round_${ROUND}_hap_2.vcf" \
        "${REF}" "${UNFILTERED}" "${HAPLOID2DIPLOID_OPTS}"

    # Add invocation of this script to VCF header
    INVOCATION_DATE="##CL=${INVOCATION}; $(date)"
    # sed processes backslash escapes in the text of its one-line `i`
    # command, so backslashes (e.g. from shell-quoted arguments) are doubled
    # to appear literally in the header.
    INVOCATION_SED=$(printf '%s' "${INVOCATION_DATE}" | sed -e 's/\\/\\\\/g')
    sed -i "/^#CHROM.*/i ${INVOCATION_SED}" "${UNFILTERED}"
    # Set correct sample name
    sed -i "s#\tSAMPLE#\t${SAMPLE_SED}#" "${UNFILTERED}"

    if $FILTER_FLAG; then
        # Filter resulting vcf
        FILTER="(((TYPE=\"snp\" | TYPE=\"mnp\") & (QUAL > ${SNP_THRESHOLD})) | ((TYPE!=\"snp\" & TYPE!=\"mnp\") & (QUAL > ${INDEL_THRESHOLD})))"
        bcftools filter --threads "${THREADS}" -s lowqual -i "${FILTER}" "${UNFILTERED}" > "${FINAL}"
        # pyvcf used in whatshap<=0.18 can't parse the ##FILTER=<ID=lowqual header line added by bcftools filter which contains escaped quotes.
        sed -i 's/\\"//g' "${FINAL}"
    else
        echo "Skipping filtering of VCF (-U option)"
        mv "${UNFILTERED}" "${FINAL}"
    fi 

    if ${PHASEOUT}; then
        FINALPHASED=round_1_phased.vcf
        run_phase_vcf "${FINAL}" "${FINALPHASED}"
        FINAL=${FINALPHASED}
    else
        # Without -p the final VCF is passed through `whatshap unphase`,
        # going through a temporary copy so the file can be rewritten in
        # place.
        mv "${FINAL}" "${FINAL}_tmp"
        whatshap unphase "${FINAL}_tmp" > "${FINAL}"
        rm -f "${FINAL}_tmp"
    fi
fi

###
# Cleanup
echo ""
echo "++ All done." 
if [[ "${DELETE}" == true ]]; then
  # Collect the intermediates by globbing instead of parsing `ls`: ls exits
  # non-zero as soon as one pattern matches nothing, which makes the
  # assignment fail and, with `set -e` active, ends the script before
  # anything is deleted or the final message is printed.
  files=()
  for candidate in *.hdf *.bam *.bai *.gz *.tbi; do
    # A pattern without a match is left as its literal text; skip it.
    if [[ -e "${candidate}" ]]; then
      files+=("${candidate}")
    fi
  done
  echo ""
  echo "- Deleting intermediate files:"
  if (( ${#files[@]} > 0 )); then
    printf '%s\n' "${files[@]}"
    rm -- "${files[@]}"
  fi
fi
echo "- Final VCF written to ${OUTPUT}/${FINAL}"
