# Functions to score decoys based on which predicted contacts they satisfy.

import os
import re

import numpy as np
import scipy.stats as ss
from Bio import PDB

# Distance cutoff at or below which a predicted contact counts as satisfied.
# It is compared against the atom-atom distances computed from the PDB
# structures (presumably Angstroms - TODO confirm).
CONTACT_DISTANCE = 8

def split_res_number_ins_code(resid):
    """
    Takes a residue id (int or str) and splits it into two parts: the actual residue number and
    the insertion code.

    A purely numeric ID (e.g. 100 or "100") yields (100, " "), i.e. a blank insertion code.
    An ID made of digits followed by non-digit characters (e.g. "100A") yields (100, "A").
    Anything else prints an error message and yields (None, None).
    """

    try:
        # If this works then the residue ID is simply a number
        # Insertion code in this case is blank
        return (int(resid), " ")
    except (TypeError, ValueError, OverflowError):
        # Not a plain number: fall through and look for an insertion code
        pass

    # Only "<digits><non-digits>" is accepted; the whole string must match
    match = re.fullmatch(r"(\d+)(\D+)", resid) if isinstance(resid, str) else None
    if match is None:
        print("Error: incorrect format for residue ID")
        return (None, None)

    return (int(match.group(1)), match.group(2))


def in_loop(resno, start, end):
    """
    Checks whether a given residue lies within a specified loop region (bounds inclusive).

    resno, start and end are (residue number, insertion code) pairs, in the form returned by
    split_res_number_ins_code. Insertion codes are only compared when the residue number
    equals a boundary residue number; they are ordered blank first, then "A" to "Z".

    Returns True if the residue is inside the loop, False otherwise. Raises ValueError if an
    insertion code that has to be compared is not blank or an upper-case letter.
    """

    number = resno[0]
    if number < start[0] or number > end[0]:
        return False

    # A tuple (not a str) so that only exact single-character codes are found
    ins_order = tuple(" ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    # Both boundary checks are applied independently, so a loop whose start and end share
    # the same residue number is bounded on both sides by insertion code.
    if number == start[0] and ins_order.index(resno[1]) < ins_order.index(start[1]):
        return False
    if number == end[0] and ins_order.index(resno[1]) > ins_order.index(end[1]):
        return False

    return True


def _contact_atom(residue):
    """Returns the atom used to measure a contact: CA for glycine, CB for any other residue."""
    return residue["CA"] if residue.resname == "GLY" else residue["CB"]


def rank_using_contacts(contact_output, contact_file, decoy_file, pdb_file, chain, start, end):
    """
    Scores every decoy by how well it satisfies a set of predicted contacts and writes the
    result to contact_output.

    contact_file holds one contact per line as "res1,res2,score"; blank lines are ignored.
    decoy_file is a multi-model PDB file whose "MODEL" records carry the decoy IDs, and
    pdb_file is the structure the loop belongs to. For each contact, a residue inside the
    loop (start to end, inclusive) is taken from the decoy; any other residue is taken from
    the first model of pdb_file. Distances are measured between CB atoms (CA for glycine).

    A contact within CONTACT_DISTANCE contributes -score to a decoy's total; a longer one
    contributes a distance-dependent term. Contacts that no decoy satisfies are excluded from
    the total. Contacts whose residues or atoms cannot be found are written as "-".

    Returns False (and writes nothing) if contact_file does not exist or contains no
    contacts; returns True once the output file has been written.
    """

    if not os.path.exists(contact_file):
        print("Contact file not found. Ranking without contacts instead")
        return False

    pred_contacts = {}
    pred_contact_list = []
    with open(contact_file) as handle:
        for line in handle:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end of the file)
                continue
            res1, res2, score = line.split(",")
            res1 = res1.strip()
            res2 = res2.strip()
            name = "Contact_%s_%s" %(res1, res2)
            pred_contacts[name] = (res1, res2, float(score))
            pred_contact_list.append(name)

    # Don't bother if there aren't any contacts in the file
    if not pred_contacts:
        return False

    # Calculate distances for all predicted contacts in all decoy structures
    parser = PDB.PDBParser(QUIET=True)
    input_struc = parser.get_structure("input", pdb_file)
    decoy_strucs = parser.get_structure("decoys", decoy_file)
    with open(decoy_file) as handle:
        decoy_ids = [line.split()[1] for line in handle if line.startswith("MODEL")]

    # Residue IDs do not change between decoys, so they are parsed once up front
    loop_start = split_res_number_ins_code(start)
    loop_end = split_res_number_ins_code(end)
    contact_residues = {
        cont: (split_res_number_ins_code(pred_contacts[cont][0]),
               split_res_number_ins_code(pred_contacts[cont][1]))
        for cont in pred_contacts
    }

    decoy_data = {}
    # Models are looked up by their index within the decoy file
    for decoy_no in range(len(decoy_strucs)):
        decoy_id = decoy_ids[decoy_no]
        decoy_struc = decoy_strucs[decoy_no]
        distances = decoy_data[decoy_id] = {}

        for cont in pred_contact_list:
            try:
                atoms = []
                for resno in contact_residues[cont]:
                    # Loop residues come from the decoy, everything else from the input model
                    if in_loop(resno, loop_start, loop_end):
                        source = decoy_struc
                    else:
                        source = input_struc[0]
                    atoms.append(_contact_atom(source[chain][(" ", resno[0], resno[1])]))
                distances[cont] = atoms[0] - atoms[1]
            except (KeyError, TypeError, ValueError):
                # Missing chain/residue/atom (KeyError) or an unparseable residue ID
                # (TypeError/ValueError): leave this contact without a distance
                continue

    # Establish which contacts are satisfied at least once
    # Contacts that are never satisfied are excluded from the score
    contacts_satisfied = set()
    for cont in pred_contact_list:
        measured = [decoy_data[decoy][cont] for decoy in decoy_data if cont in decoy_data[decoy]]
        res1, res2 = pred_contacts[cont][0], pred_contacts[cont][1]
        if not measured:
            print("No distances were calculated for contact between residues %s and %s: is one of these residues missing?" %(res1, res2))
        elif min(measured) <= CONTACT_DISTANCE:
            contacts_satisfied.add(cont)
        else:
            print("Contact between residues %s and %s is never satisfied and will be ignored." %(res1, res2))

    # Calculate contact scores and write output file
    with open(contact_output, "w") as out:
        out.write("Decoy, NoContactsSatisfied, %s, Score\n" %(", ".join(pred_contact_list)))

        for decoy, distances in decoy_data.items():
            no_satisfied = sum(1 for dist in distances.values() if dist <= CONTACT_DISTANCE)
            fields = [decoy, "%d" %(no_satisfied)]

            score = 0
            for cont in pred_contact_list:
                if cont not in distances:
                    fields.append("-")
                    continue

                dist = distances[cont]
                fields.append("%f" %(dist))

                if cont not in contacts_satisfied:
                    continue

                contact_score = pred_contacts[cont][2]
                if dist <= CONTACT_DISTANCE:
                    score += -contact_score
                else:
                    # Reward decays with the squared excess distance; penalty grows with it
                    excess = dist - CONTACT_DISTANCE
                    score += -contact_score * np.exp(-excess**2) + contact_score*(excess/dist)

            fields.append("%f" %(score))
            out.write(", ".join(fields) + "\n")

    return True


def _read_final_scores(score_file):
    """Maps decoy ID (first field) to score (last field) for each row after the header."""
    scores = {}
    with open(score_file) as handle:
        next(handle, None)  # skip the header row
        for line in handle:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end of the file)
                continue
            fields = line.split(", ")
            scores[fields[0]] = float(fields[-1])
    return scores


def select_top_500_with_contacts(RAPDF_file, contact_file):
    """
    Uses the scores from the RAPDF as well as the contact score to select the top 500 decoys.

    Both files are ", "-separated with a header row; the decoy ID is the first field and the
    score is the last field of each row. Decoys are ranked separately by contact score and by
    RAPDF score (lowest score = rank 1, ties share the lowest rank), and the two ranks are
    summed. Decoys are returned in order of increasing combined rank, with ties broken by
    RAPDF rank and then by their order in contact_file.

    Returns a list of at most 500 decoy IDs. Raises KeyError if a decoy appears in only one
    of the two files.
    """

    contact_scores = _read_final_scores(contact_file)
    rapdf_scores = _read_final_scores(RAPDF_file)

    for decoy in rapdf_scores:
        if decoy not in contact_scores:
            raise KeyError("Decoy %s has an RAPDF score but no contact score" %(decoy))

    decoys = list(contact_scores)

    # rankdata does not depend on input order, so no pre-sorting is needed
    contact_ranks = ss.rankdata([contact_scores[decoy] for decoy in decoys], method="min")
    rapdf_ranks = ss.rankdata([rapdf_scores[decoy] for decoy in decoys], method="min")

    sort_keys = {
        decoy: (contact_rank + rapdf_rank, rapdf_rank)
        for decoy, contact_rank, rapdf_rank in zip(decoys, contact_ranks, rapdf_ranks)
    }

    # sorted() is stable, so remaining ties keep their contact_file order
    return sorted(decoys, key=sort_keys.__getitem__)[:500]
