import os
import shutil
import subprocess

# Modeller is an optional dependency: attempt the imports and record the
# outcome in modeller_installed, which write_rank_top500 checks before it
# sets up SOAP-Loop scoring.
try:
    from modeller import *
    from modeller.scripts import complete_pdb
    from modeller import soap_loop
    modeller_installed = True
except:
    # NOTE(review): the bare except treats *any* failure raised during the
    # imports above (not only ImportError) as "Modeller is not available".
    # Confirm whether that breadth is intentional before narrowing it.
    modeller_installed = False
from sphinx.sidechains.bears import run_bears

def which(name):
    """Locate an executable called *name* on the PATH (used to find SCWRL).

    The PATH entries are tried in order and the first match wins.

    Args:
        name: File name of the executable to look for, e.g. "Scwrl4".

    Returns:
        The PATH entry joined with *name* for the first entry that holds an
        executable regular file of that name, or "" if there is none or the
        PATH environment variable is not set.
    """
    search_path = os.getenv("PATH")
    if search_path is None:
        # os.getenv returns None for an unset variable; there is nothing to search
        return ""

    for directory in search_path.split(os.path.pathsep):
        full_path = os.path.join(directory, name)
        # Require a runnable regular file so that a directory or a
        # non-executable file with the same name is not reported as a hit
        if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
            return full_path
    return ""

def refine_scwrl(**kwargs):
    """Wrapper to run SCWRL on the output structure.

    Scwrl4 is looked up on the PATH and run with ``output_file`` as both its
    input (``-i``) and output (``-o``) and ``SCWRL_file`` as its sequence file
    (``-s``); its standard output is discarded. correct_SCWRL is then called
    on the result.

    Keyword Args:
        output_file: PDB file to refine; it is overwritten.
        SCWRL_file: Sequence file handed to Scwrl4.
        start, end, chain, pdb_file: Passed through to correct_SCWRL together
            with ``output_file``.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
        RuntimeError: If no Scwrl4 executable is found on the PATH.
    """
    output_file = kwargs['output_file']
    SCWRL_file = kwargs['SCWRL_file']
    start = kwargs['start']
    end = kwargs['end']
    chain = kwargs['chain']
    pdb_file = kwargs['pdb_file']

    SCWRL_exec = which("Scwrl4")
    if not SCWRL_exec:
        # which() signals "not found" with an empty string; fail with a clear
        # message instead of trying to execute an empty command
        raise RuntimeError("Scwrl4 executable not found on PATH")

    # Pass an argument list without a shell so that paths containing spaces
    # or shell metacharacters reach Scwrl4 unchanged. As before, the exit
    # status is not inspected.
    with open(os.devnull, "w") as devnull:
        subprocess.call(
            [SCWRL_exec, "-i", output_file, "-s", SCWRL_file, "-o", output_file],
            stdout=devnull)

    # Correct for any mistakes (missing residues etc.) caused by SCWRL in the rest of the file
    correct_SCWRL(output_file, start, end, chain, pdb_file)

def refine_bears(**kwargs):
    """Run the BEARS side-chain method on ``kwargs['output_file']``.

    Only the ``output_file`` keyword is read; other keywords are accepted and
    ignored. The same path is given to run_bears as both arguments
    (presumably input and output, i.e. the file is rewritten in place --
    confirm against run_bears).

    Raises:
        KeyError: If ``output_file`` is not supplied.
    """
    structure_path = kwargs['output_file']
    run_bears(structure_path, structure_path)

def code_convert(aa):
    """Translate an amino-acid code between three-letter and one-letter form.

    The direction is chosen by the length of *aa*: a value of length 3 is
    looked up as a three-letter code and anything else as a one-letter code.
    Lookups are case-sensitive (only upper-case codes are known).

    Returns:
        The converted code, or the placeholder "X" (for a length-3 input) /
        "XXX" (for any other input) when the code is not one of the twenty
        standard residues.
    """
    # (three-letter, one-letter) pairs for the twenty standard amino acids
    code_pairs = (
        ("ALA", "A"), ("GLU", "E"), ("GLN", "Q"), ("ASP", "D"), ("ASN", "N"),
        ("LEU", "L"), ("GLY", "G"), ("LYS", "K"), ("SER", "S"), ("VAL", "V"),
        ("ARG", "R"), ("THR", "T"), ("PRO", "P"), ("ILE", "I"), ("MET", "M"),
        ("PHE", "F"), ("TYR", "Y"), ("CYS", "C"), ("TRP", "W"), ("HIS", "H"),
    )

    if len(aa) == 3:
        return dict(code_pairs).get(aa, "X")

    # Any other length is treated as a one-letter code
    return dict((one, three) for three, one in code_pairs).get(aa, "XXX")

def get_before_after_loop_for_SCWRL(start, end, loop_chain, pdb_path, loop_seq):
    """Split a PDB file around a loop and build the matching SCWRL sequence.

    Residues are identified by chain ID + residue number + insertion code
    (columns 22-27 of an ATOM record). Any residue whose ATOM records do not
    include all four backbone atoms (N, CA, C, O) is dropped from both text
    fragments and from the sequence.

    Args:
        start: Residue number of the first loop residue (anything int() accepts).
        end: Residue number of the last loop residue (anything int() accepts).
        loop_chain: Chain ID of the loop.
        pdb_path: Path of the PDB file to read.
        loop_seq: One-letter sequence of the loop.

    Returns:
        A tuple ``(pre_loop_string, after_loop_string, SCWRL_seq)``:

        * ``pre_loop_string`` -- the file's lines up to, but not including,
          the last ATOM record of residue ``start - 1`` on ``loop_chain``.
        * ``after_loop_string`` -- the file's lines from the N atom record of
          residue ``end + 1`` on ``loop_chain`` to the end of the file.
        * ``SCWRL_seq`` -- one lower-case letter per residue kept in the
          pre-loop fragment, then ``loop_seq`` in upper case, then one
          lower-case letter per residue kept in the post-loop fragment
          (presumably so SCWRL only repacks the upper-case loop residues --
          confirm against the SCWRL documentation).

    Raises:
        IndexError: If residue ``start - 1`` has no ATOM record, or residue
            ``end + 1`` has no N atom record, on ``loop_chain``.
    """
    backbone_atoms = (" N  ", " CA ", " C  ", " O  ")

    # Read in native structure
    with open(pdb_path) as handle:
        native_file = handle.readlines()
    atoms = [line for line in native_file if line.startswith("ATOM")]

    # Atom names seen for every residue
    atom_names = {}
    for line in atoms:
        atom_names.setdefault(line[21:27], set()).add(line[12:16])

    # Residues with an incomplete backbone are left out everywhere below
    to_skip = set()
    for residue_key, names in atom_names.items():
        if not all(atom in names for atom in backbone_atoms):
            to_skip.add(residue_key)

    # Get residues before the loop in the pdb file
    before_anchor = int(start) - 1
    before_lines = [line for line in atoms
                    if line[21] == loop_chain and int(line[22:26]) == before_anchor]
    if not before_lines:
        raise IndexError("no ATOM records for residue %d of chain %r in %s"
                         % (before_anchor, loop_chain, pdb_path))
    first_index = native_file.index(before_lines[-1])
    pre_loop = [line for line in native_file[:first_index] if line[21:27] not in to_skip]

    # Get residues after the loop in the pdb file
    after_anchor = int(end) + 1
    after_lines = [line for line in atoms
                   if line[12:16] == " N  " and line[21] == loop_chain
                   and int(line[22:26]) == after_anchor]
    if not after_lines:
        raise IndexError("no N atom record for residue %d of chain %r in %s"
                         % (after_anchor, loop_chain, pdb_path))
    last_index = native_file.index(after_lines[0])
    after_loop = [line for line in native_file[last_index:] if line[21:27] not in to_skip]

    added = set()  # residues already represented in the sequence

    def flank_letters(lines):
        # One lower-case letter per not-yet-seen residue. Every ATOM line left
        # in a fragment belongs to a complete-backbone residue, because the
        # incomplete ones were filtered out through to_skip above.
        letters = []
        for line in lines:
            if not line.startswith("ATOM"):
                continue
            residue_key = line[21:27]
            if residue_key in added:
                continue
            added.add(residue_key)
            letters.append(code_convert(line[17:20]).lower())
        return "".join(letters)

    SCWRL_seq = flank_letters(pre_loop) + loop_seq.upper() + flank_letters(after_loop)

    return "".join(pre_loop), "".join(after_loop), SCWRL_seq

def correct_SCWRL(SCWRL_output_file, start, end, ch, pdb_path):
    """Restore everything outside the loop region from the original PDB file.

    SCWRL_output_file is rewritten as the concatenation of:

    * the lines of ``pdb_path`` before the last ATOM record of residue
      ``start - 1`` on chain ``ch``,
    * the lines of ``SCWRL_output_file`` from its last ATOM record of residue
      ``start - 1`` up to (not including) its N atom record of residue
      ``end + 1``,
    * the lines of ``pdb_path`` from the N atom record of residue ``end + 1``
      to the end of the file.

    Args:
        SCWRL_output_file: Path of the SCWRL output; overwritten in place.
        start: Residue number of the first loop residue (anything int() accepts).
        end: Residue number of the last loop residue (anything int() accepts).
        ch: Chain ID of the loop.
        pdb_path: Path of the original PDB file.

    Raises:
        IndexError: If either file lacks an ATOM record for residue
            ``start - 1`` or an N atom record for residue ``end + 1`` on
            chain ``ch``.
    """
    before_anchor = int(start) - 1
    after_anchor = int(end) + 1

    def anchor_indices(lines, path):
        # Line indices of the last ATOM record of the residue before the loop
        # and of the N atom record of the residue after it
        atoms = [line for line in lines if line.startswith("ATOM")]
        before = [line for line in atoms
                  if line[21] == ch and int(line[22:26]) == before_anchor]
        if not before:
            raise IndexError("no ATOM records for residue %d of chain %r in %s"
                             % (before_anchor, ch, path))
        after = [line for line in atoms
                 if line[12:16] == " N  " and line[21] == ch
                 and int(line[22:26]) == after_anchor]
        if not after:
            raise IndexError("no N atom record for residue %d of chain %r in %s"
                             % (after_anchor, ch, path))
        return lines.index(before[-1]), lines.index(after[0])

    # Read in native structure
    with open(pdb_path) as handle:
        native_file = handle.readlines()
    native_first, native_last = anchor_indices(native_file, pdb_path)

    # Get the decoy structure with sidechains from SCWRL output
    with open(SCWRL_output_file) as handle:
        decoy_file = handle.readlines()
    decoy_first, decoy_last = anchor_indices(decoy_file, SCWRL_output_file)

    merged = (native_file[:native_first]
              + decoy_file[decoy_first:decoy_last]
              + native_file[native_last:])

    with open(SCWRL_output_file, "w") as f:
        f.write("".join(merged))

    return
    
def insert_decoy(chain, decoy, pre_loop_string, after_loop_string, output_file):
    """Write a complete model: pre-loop text + decoy atoms + post-loop text.

    Args:
        chain: Chain ID written into column 22 of every decoy ATOM record.
        decoy: Iterable of lines for one decoy; only lines starting with
            "ATOM" are used.
        pre_loop_string: Text written before the decoy atoms.
        after_loop_string: Text written after the decoy atoms.
        output_file: Path of the file to create or overwrite.
    """
    # Keep only the ATOM records and correct their chain codes
    decoy_string = "".join(
        line[:21] + chain + line[22:]
        for line in decoy
        if line.startswith("ATOM"))

    # Write file; the context manager closes it even if the write fails
    with open(output_file, "w") as f:
        f.write(pre_loop_string + decoy_string + after_loop_string)

    return

def select_top_500(ranking_file):
    """Return the names of the (up to) 500 best-scoring decoys in a ranking file.

    The first line of the file is skipped as a header. Every other non-blank
    line must have the form ``<decoy>, <score>``. Decoys are ordered by
    ascending score, with ties ordered by decoy name.

    Args:
        ranking_file: Path of the ranking file to read.

    Returns:
        A list of at most 500 decoy names, lowest score first.

    Raises:
        ValueError: If a non-blank data line does not split into exactly two
            ", "-separated fields or its score is not a valid float.
    """
    with open(ranking_file) as handle:
        rows = handle.readlines()[1:]  # drop the header line

    scored = []
    for row in rows:
        row = row.rstrip()
        if not row:
            # Tolerate blank lines, e.g. an empty line at the end of the file
            continue
        decoy, score = row.split(", ")
        scored.append((float(score), decoy))

    scored.sort()

    return [decoy for _, decoy in scored[:500]]
    
def write_rank_top500(chain, start, end, top500, decoy_file, pdb_file, scriptpath, loop_seq, dbdir, ranking_method="soaploop", sidechain_method="pears"):
    """Build complete models for the selected decoys, score them and write a ranking.

    For every model in ``decoy_file`` whose code (second field of its MODEL
    line) is in ``top500``, the decoy atoms are spliced into the structure
    from ``pdb_file``, side chains are predicted, and the model is scored.
    Files written next to ``decoy_file``:

    * ``SCWRL_sequence.txt`` -- sequence from get_before_after_loop_for_SCWRL.
    * ``CompleteModels/<code>.pdb`` -- one complete model per selected decoy
      (an existing ``CompleteModels`` folder is deleted first).
    * ``top_500_decoy_structures.pdb`` -- MODEL/ATOM/ENDMDL lines of the
      selected decoys, in the order they occur in ``decoy_file``.
    * ``ranking_final.csv`` -- ``Rank, Decoy, Score`` rows, lowest score first.

    Args:
        chain: Chain ID of the loop.
        start: Residue number of the first loop residue.
        end: Residue number of the last loop residue.
        top500: Collection of decoy codes to process.
        decoy_file: Multi-model file holding the decoys.
        pdb_file: PDB file the decoys are inserted into.
        scriptpath: Directory containing ``ranking/ranking_execs``.
        loop_seq: One-letter sequence of the loop.
        dbdir: Not used by this function.
        ranking_method: "soaploop" (needs Modeller), "ddfire" or "rwplus".
        sidechain_method: "pears" (refine_bears) or "scwrl4" (refine_scwrl).

    Raises:
        ValueError: If ``ranking_method`` or ``sidechain_method`` is unknown.
        RuntimeError: If "soaploop" is requested but Modeller is unavailable.
    """
    # Validate the options before touching the file system, so that a typo
    # does not surface as a NameError after the output folder was wiped
    sidechain_methods = {"scwrl4": refine_scwrl, "pears": refine_bears}
    if sidechain_method not in sidechain_methods:
        raise ValueError("Unknown sidechain_method %r (expected 'scwrl4' or 'pears')"
                         % (sidechain_method,))
    if ranking_method not in ("soaploop", "ddfire", "rwplus"):
        raise ValueError("Unknown ranking_method %r (expected 'soaploop', 'ddfire' or 'rwplus')"
                         % (ranking_method,))
    if ranking_method == "soaploop" and not modeller_installed:
        raise RuntimeError("ranking_method='soaploop' requires Modeller, which could not be imported")
    SIDECHAIN_EXEC = sidechain_methods[sidechain_method]

    # The dDFIRE/RWplus branches change the working directory below, so make
    # every path absolute first; relative inputs would otherwise stop resolving
    decoy_file = os.path.abspath(decoy_file)
    pdb_file = os.path.abspath(pdb_file)
    work_dir = os.path.dirname(decoy_file)
    top500_file = os.path.join(work_dir, "top_500_decoy_structures.pdb")

    DDFIRE_DIR = os.path.abspath(os.path.join(scriptpath, "ranking", "ranking_execs", "dDFIRE"))
    RWPLUS_DIR = os.path.abspath(os.path.join(scriptpath, "ranking", "ranking_execs", "RWplus"))

    pre_loop_string_SCWRL, after_loop_string_SCWRL, SCWRL_seq = get_before_after_loop_for_SCWRL(start, end, chain, pdb_file, loop_seq)

    SCWRL_file = os.path.join(work_dir, "SCWRL_sequence.txt")
    with open(SCWRL_file, "w") as f:
        f.write(SCWRL_seq)

    output_folder = os.path.join(work_dir, "CompleteModels")
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    if ranking_method == "soaploop":
        # Prepare Modeller/SOAP-Loop
        log.none()
        env = environ()
        env.libs.topology.read(file="$(LIB)/top_heav.lib")
        env.libs.parameters.read(file="$(LIB)/par.lib")
        # The scorer is created with its default library; an earlier revision
        # passed library=<path to soap_loop.hdf5> to point at a custom database
        sp = soap_loop.Scorer()

    with open(decoy_file) as handle:
        data = handle.readlines()
    decoycount = min(500, sum(1 for line in data if line.startswith("MODEL")))
    wanted = set(top500)  # constant-time membership test inside the loop

    def external_score(executable, field):
        # Run a scoring executable on output_file and parse one
        # whitespace-separated field of its standard output as the score
        out = subprocess.Popen([executable, output_file], stdout=subprocess.PIPE)
        result = out.communicate()
        return float(result[0].split()[field])

    to_write = []
    top500_decoy_codes = []
    score_list = []

    # The external scorers are run from inside their own directory, as before;
    # the caller's working directory is restored afterwards
    exec_dirs = {"ddfire": DDFIRE_DIR, "rwplus": RWPLUS_DIR}
    previous_cwd = None
    if ranking_method in exec_dirs:
        previous_cwd = os.getcwd()
        os.chdir(exec_dirs[ranking_method])

    try:
        counter = 0
        decoy = []
        decoy_code = None
        for line in data:
            if line.startswith("MODEL"):
                decoy_code = line.split()[1]
                decoy = [line]
            elif line.startswith("ATOM"):
                decoy.append(line)
            elif line.startswith("ENDMDL"):
                decoy.append(line)
                if decoy_code not in wanted:
                    continue

                to_write += decoy
                top500_decoy_codes.append(decoy_code)
                counter += 1
                print("\nDecoy %d/%d:" %(counter, decoycount))

                output_file = os.path.join(output_folder, "%s.pdb" %decoy_code)
                insert_decoy(chain, decoy, pre_loop_string_SCWRL, after_loop_string_SCWRL, output_file)

                ## Run SCWRL or BEARS - OPIG's backbone-dependent side chain predictor
                print("Predicting sidechains...")
                SIDECHAIN_EXEC(output_file=output_file,
                               SCWRL_file=SCWRL_file, pdb_file=pdb_file,
                               start=start, end=end, chain=chain)

                if ranking_method == "soaploop":
                    # Score with SOAP-Loop
                    mdl = complete_pdb(env, output_file, transfer_res_num=True)
                    s = selection(mdl)
                    score = s.assess(sp)
                elif ranking_method == "ddfire":
                    # Score with dDFIRE
                    score = external_score(os.path.join(DDFIRE_DIR, "dDFIRE"), 1)
                    print("dDFIRE score: %f" %score)
                else:
                    # Score with (cal)RWplus
                    score = external_score(os.path.join(RWPLUS_DIR, "calRWplus"), 3)
                    print("RWplus score: %f" %score)
                score_list.append(score)
    finally:
        if previous_cwd is not None:
            os.chdir(previous_cwd)

    with open(top500_file, "w") as f:
        f.write("".join(to_write))

    # Stable sort: decoys with equal scores keep their order from decoy_file
    ranked = sorted(zip(top500_decoy_codes, score_list), key=lambda pair: pair[1])
    with open(os.path.join(work_dir, "ranking_final.csv"), "w") as SP_output:
        SP_output.write("Rank, Decoy, Score\n")
        for rank, (code, score) in enumerate(ranked, 1):
            SP_output.write("%d, %s, %f\n" %(rank, code, score))
    return
