#!/usr/bin/env python

import os,sys,random,json
from Bio.PDB import PDBParser
from Bio.PDB import PDBIO
import Bio.PDB
from Bio.PDB.Polypeptide import is_aa, three_to_one
from optparse import OptionParser
from os import listdir
from os.path import isfile, join

# Optional dependencies: ANARCI (numbering) and ABDB (region annotation) are
# only needed by IpatchJSONWriter. If either import fails the script still
# runs and json_available tells the main section to skip the JSON output.
# Both outcomes are reported on stdout.
# NOTE(review): `Accept` is imported but not referenced anywhere in this file.
try: # Import extra dependencies for json output. Allows backward compatibility when not used within sabpred
    from anarci import number, scheme_names
    from ABDB.AB_Utils.region_definitions import Accept, annotate_regions
    print("json available!")
    json_available = True
except ImportError:
    print("json unavailable :-(")
    json_available = False

def runPSA(code,the_rand):
    """Run the bundled PSA executable on one constrained structure.

    The tool is invoked through the shell on
    ``<mypath>/temp<the_rand>/<code>.pdb`` and its standard output is
    redirected to ``<mypath>/temp<the_rand>/temp<code>.txt``. The exact
    command line is echoed to stdout before it is executed.

    code     -- file stem of the structure (the script passes "AB" and "AG").
    the_rand -- suffix identifying this run's temp directory.

    Relies on the module-level ``mypath``. The exit status of the command is
    not inspected, so a failing PSA run is not reported here.
    """
    # On macOS the executable name carries a "_mac" suffix.
    suffix = "_mac" if sys.platform == 'darwin' else ""
    work_dir = mypath+"/temp"+the_rand
    # Build the command once so that what is printed is exactly what is run.
    command = mypath+"/aux/psa"+suffix+" -t "+work_dir+"/"+code+".pdb > "+work_dir+"/temp"+code+".txt"
    print(command)
    os.system(command)

#Single file mode - renumber the selected chains of the antibody (AB) or antigen (AG) file
def constrain(options,code,the_rand):
    """Write a renumbered, single-chain copy of the requested chains.

    The input file and chain list are taken from ``options.file_ab`` /
    ``options.chains_ab`` when ``code == 'AB'`` and from ``options.file_ag`` /
    ``options.chains_ag`` otherwise.

    Residues with a blank hetero flag in the requested chains are renumbered
    sequentially starting at 1 (the counter runs on across chains) and lose
    their insertion code. Chains that were not requested get an integer id and
    are left out of the output. The result is written to
    ``<mypath>/temp<the_rand>/<code>.pdb`` with TER and HETATM lines removed
    and the chain column of every ATOM line set to "A".

    Returns a dict with two sub-dicts:
        "fwd": new residue number -> "<chain>_<resnum>_<icode>" (icode may be empty)
        "bwd": the same label     -> new residue number

    Relies on the module-level ``mypath`` and writes a scratch file
    ``TMP<the_rand>.pdb`` in the current working directory.
    """
    mapping = {"fwd": {}, "bwd": {}}
    if code == 'AB':
        pdb_file = options.file_ab
        chains = options.chains_ab
    else:
        pdb_file = options.file_ag
        chains = options.chains_ag

    structure = PDBParser().get_structure("input", pdb_file)
    chains_to_choose = list(chains)

    next_number = 1
    placeholder_chain_id = 1
    for model in structure:
        for chain in model:
            if chain.id not in chains_to_choose:
                # An integer id marks the chain as unwanted (see ConstrSelect)
                # and cannot clash with a real, string-valued chain id.
                chain.id = placeholder_chain_id
                placeholder_chain_id += 1
                continue

            standard = [residue for residue in chain if residue.id[0] == ' ']

            # Pass 1: record the mapping and park every residue on a high
            # placeholder number - presumably so that the final sequential ids
            # assigned in pass 2 cannot collide with ids still in use.
            for offset, residue in enumerate(standard):
                old_id = residue.id
                label = str(chain.id)+"_"+str(old_id[1])+"_"+str(old_id[2]).strip()
                mapping['fwd'][next_number+offset] = label
                mapping['bwd'][label] = next_number+offset
                residue.id = (' ', 100000000+offset, ' ')

            # Pass 2: assign the final sequential numbers.
            for offset, residue in enumerate(standard):
                residue.id = (' ', next_number+offset, ' ')
            next_number += len(standard)

    class ConstrSelect(Bio.PDB.Select):
        """Keep only chains whose id was not replaced by an integer."""
        def accept_chain(self, chain):
            return 0 if isinstance(chain.id, int) else 1

    tmp_name = "TMP"+the_rand+".pdb"
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(tmp_name, ConstrSelect())
    try:
        with open(tmp_name, 'r') as f_tmp, open(mypath+"/temp"+the_rand+"/"+code+'.pdb', 'w') as f_out:
            for line in f_tmp:
                # Remove the HETATM and TER lines
                if line[0:3] == "TER" or line[0:6] == "HETATM":
                    continue
                if line.startswith("ATOM"):
                    # Collapse everything onto a single chain "A".
                    line = line[:21] + "A" + line[22:]
                f_out.write(line)
    finally:
        # The scratch file is removed even when the copy above fails.
        os.remove(tmp_name)
    return mapping

#Returns true if the PDB has all the chains provided for the given pdb
def check_chains(options,code):
    """Tell whether every requested chain id occurs in the matching PDB file.

    For ``code == 'AB'`` the antibody file/chains from ``options`` are used,
    for any other value the antigen ones. Chain ids are collected over all
    models of the parsed structure.

    Returns True when each requested chain id was found, False otherwise.
    """
    if code == 'AB':
        pdb_file, wanted = options.file_ab, options.chains_ab
    else:
        pdb_file, wanted = options.file_ag, options.chains_ag

    structure = PDBParser().get_structure("input", pdb_file)
    present = [chain.id for model in structure for chain in model]
    return all(chain_id in present for chain_id in wanted)

#Map the single-chain i-Patch output back to the original chain ids and residue numbering
def constrainBACK(pdb_file,pdb_file_out,the_map):
    """Restore original chain ids and residue numbers in a renumbered PDB file.

    pdb_file     -- input file whose ATOM records carry the sequential numbers.
    pdb_file_out -- file to write; it receives the rewritten ATOM records only,
                    every other line of the input is dropped.
    the_map      -- dict: sequential residue number (int) ->
                    "<chain>_<resnum>_<icode>" (icode may be empty).

    Raises KeyError when an ATOM record's residue number is not in the_map and
    ValueError when the residue-number columns do not hold an integer.
    """
    restored = []
    with open(pdb_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Match on the record name; a substring test would also pick up
            # HETATM records and remarks that merely mention "ATOM".
            if not line.startswith("ATOM"):
                continue
            # resSeq occupies columns 23-26 of an ATOM record (line[22:26]).
            res_num = int(line[22:26])
            parts = the_map[res_num].split("_")
            new_chain = parts[0]
            # Right-justify over the full 4-column field so that 4-character
            # numbers do not shift the rest of the record.
            new_resnum = parts[1].rjust(4)
            # With no insertion code to restore, keep the column as it was.
            new_icode = parts[2] if parts[2] else line[26:27]
            restored.append(line[:21] + new_chain + new_resnum + new_icode + line[27:])

    with open(pdb_file_out, 'w') as f_out:
        f_out.write("".join(entry + "\n" for entry in restored))

def fileExists(filename):
    """Report whether *filename* can be opened for reading.

    Returns False when opening raises IOError, True otherwise.
    """
    try:
        handle = open(filename)
    except IOError:
        return False
    handle.close()
    return True

def checkAllIsOk(mypath,the_rand,options):
    """Verify the antibody and antigen inputs before the pipeline starts.

    Checks that ``<mypath>/AB.pdb`` and ``<mypath>/AG.pdb`` can be opened,
    that the antigen file is not empty, and that the chains requested in
    ``options`` are present (via check_chains). On the first failed check of
    each file a message is printed and crash(options) is called.

    the_rand is accepted but not used.
    """
    antibody_path = mypath+"/AB.pdb"
    antigen_path = mypath+"/AG.pdb"

    if not fileExists(antibody_path):
        print(Exception('Something wrong with the antibody file'))
        crash(options)
    elif not check_chains(options,"AB"):
        print(Exception('Not all the antibody chains could be found in the antibody file'))
        crash(options)

    if not fileExists(antigen_path):
        print(Exception('Something wrong with the antigen file'))
        crash(options)
    elif os.stat(antigen_path).st_size == 0:
        print(Exception('Something wrong with the antigen file'))
        crash(options)
    elif not check_chains(options,"AG"):
        print(Exception('Not all the antigen chains could be found in the antigen file'))
        crash(options)


def crash(options):
    """Record the job as crashed and terminate the script.

    Writes "Crashed" to ``<options.job_id>/output_folder/status.txt`` and then
    raises SystemExit.
    """
    with open(options.job_id+"/output_folder/status.txt",'w') as s_file:
        s_file.write("Crashed")
    # sys.exit() instead of quit(): quit() is a helper installed by the site
    # module for interactive use and is not guaranteed to exist in programs.
    # No argument is passed, so the exit status is unchanged.
    sys.exit()


class IpatchJSONWriter(object):
    '''
    Number the antibody chains of a scored PDB file, locate their CDRs and write a JSON
    summary in which the per-residue score (read from the B-factor column) is aligned
    with the chain sequence.
    '''
    def __init__( self, file_in, file_out, chains, overwrite=False ):
        '''
        Process file_in and write the JSON summary to file_out.

        @param file_in: PDB file carrying the scores as B-factors.
        @param file_out: Path of the JSON file to create.
        @param chains: Iterable of chain identifiers looked up in the first model.
        @param overwrite: When true, file_in is rewritten from the in-memory structure,
                          whose B-factors outside the numbered region have been zeroed.
        '''
        structure = self._read_structure( file_in )
        recognised = {}
        for chain_id in chains:
            extracted = self._do_extract( structure[0][chain_id], 'imgt' )
            if extracted:
                recognised[chain_id] = extracted
            else:
                print('Chain %s was not recognised as an antibody chain. All i-Patch scores have been set to zero.'%chain_id)

        summary = self._compile_json( recognised )
        self._write_json( summary, file_out )
        # Overwrite the input, limiting scored residues to the Fv only
        if overwrite:
            self._output_structure( structure, file_in )

    def _read_structure(self,fin):
        '''
        Parse fin quietly and return the structure (fin is also used as its id).
        '''
        parser = PDBParser(QUIET=True)
        return parser.get_structure(fin, fin)

    def _output_structure( self, structure, fin ):
        '''
        Save structure in PDB format to the path fin, replacing the file.
        '''
        writer = PDBIO()
        writer.set_structure( structure )
        with open( fin, 'w' ) as handle:
            writer.save( handle )

    def _extract_details(self, chain):
        '''
        Collect the amino-acid residues of chain.

        Returns (pairs, sequence, scores): pairs is a list of (residue, one-letter code),
        sequence the joined one-letter codes and scores the CA B-factor of every residue
        (0 where a residue has no CA atom).
        NOTE(review): is_aa is called with standard=False; whether three_to_one accepts
        every residue name that passes that filter should be confirmed.
        '''
        pairs, letters, scores = [], [], []
        for residue in chain:
            if not is_aa( residue, standard=False ):
                continue
            letter = three_to_one( residue.get_resname() )
            pairs.append( (residue, letter) )
            letters.append( letter )
            scores.append( residue['CA'].get_bfactor() if 'CA' in residue else 0 )
        return pairs, ''.join( letters ), scores

    def _do_extract(self, chain,scheme):
        '''
        Number one chain and restrict its scores to the numbered region.

        Returns (pairs, sequence, scores, numbering, chain type, offset) when the chain
        type reported by the numbering is accepted; otherwise every atom B-factor of the
        chain's amino-acid residues is set to 0 and None is returned.
        NOTE(review): the accepted types are tested with `in 'HL'` here while
        _compile_json tests `in 'KL'` - confirm which chain types number() can return.
        '''
        pairs, sequence, scores = self._extract_details(chain)
        numbering, ctype = number( sequence, scheme )
        if not (ctype and ctype in 'HL'):
            # Reset non-antibody chains to have a 0 bfactor
            for residue, _letter in pairs:
                for atom in residue:
                    atom.set_bfactor(0)
            return None

        # Remove gaps that ANARCI can put in for imgt.
        numbering = [ (position, letter) for position, letter in numbering if letter != '-' ]
        numbered_seq = ''.join( [ item[1] for item in numbering ] )
        # Find the start point of the numbered stretch within the full sequence
        offset = sequence.index( numbered_seq )
        scores = self._restrict_to_fv( pairs, scores, offset, offset+len(numbering) )
        return pairs, sequence, scores, numbering, ctype, offset

    def _get_all_numberings(self, sequence):
        '''
        Number sequence with every scheme in scheme_names whose name is longer than one
        character.

        Returns a dict "<scheme>_numbering" -> list as long as the sequence, holding the
        position label for numbered residues and '-' before and after them.
        '''
        collected = {}
        for scheme in scheme_names:
            if len(scheme) == 1:
                continue
            numbering, ctype = number( sequence, scheme )
            # Remove gaps that ANARCI can put in for imgt.
            numbering = [ (position, letter) for position, letter in numbering if letter != '-' ]
            # Find the start point
            offset = sequence.index( ''.join( [ item[1] for item in numbering ] ) )
            labels = [ ('%d%s'%item[0]).strip() for item in numbering ]
            trailing = max( 0, len(sequence) - len(numbering) - offset )
            collected[scheme+'_numbering'] = ['-']*offset + labels + ['-']*trailing
        return collected

    def _compile_json(self, chain_details):
        '''
        Build the dictionary that is written out as JSON.

        For each chain:
            chain_type
            sequence (string)
            ipatch_score (the restricted B-factors)
            pdb_numbering
            one "<scheme>_numbering" list per numbering scheme, gapped to align
            CDR index ranges relative to the sequence for each definition
        plus, when at least one chain is present, the keys 'heavy' / 'light',
        'max_ipatch_score' and 'ipatch_ranking'.
        '''
        definitions = ['imgt','chothia','kabat','north', 'contact']
        JSON = {}
        best_score = 0
        for chain, details in chain_details.items():
            entry = {}
            JSON[chain] = entry
            chain_type = details[4]
            entry['chain_type'] = chain_type
            if chain_type == 'H':
                JSON['heavy'] = chain
            if chain_type in 'KL':
                JSON['light'] = chain

            entry['sequence'] = details[1]
            entry['ipatch_score'] = details[2]
            entry['pdb_numbering'] = [ ('%d%s'%pair[0].id[1:]).strip() for pair in details[0] ]
            # Get all the numbering schemes (numbering done multiple times)
            entry.update( self._get_all_numberings( entry['sequence'] ) )
            for definition in definitions:
                entry[definition] = self._get_range( details[3], chain_type, 'imgt', definition, details[5] )
            best_score = max( best_score, max( entry['ipatch_score'] ) )
        if not JSON:
            return JSON
        JSON['max_ipatch_score'] = best_score

        # Rank the residues by their ipatch score, best first.
        # Handle cases when only a heavy or only a light chain has been submitted
        has_heavy = 'heavy' in JSON
        has_light = 'light' in JSON
        if has_heavy and has_light:
            heavy_scores = JSON[JSON['heavy']]['ipatch_score']
            light_scores = JSON[JSON['light']]['ipatch_score']
            n_heavy = len( heavy_scores )
            # Indices past the heavy chain belong to the light chain.
            JSON['ipatch_ranking'] = [
                ['light', index-n_heavy] if index >= n_heavy else ['heavy', index]
                for index in self._rankdata( heavy_scores + light_scores )
            ]
        elif has_heavy:
            ranked = self._rankdata( JSON[JSON['heavy']]['ipatch_score'] )
            JSON['ipatch_ranking'] = [ ['heavy', index] for index in ranked ]
        elif has_light:
            ranked = self._rankdata( JSON[JSON['light']]['ipatch_score'] )
            JSON['ipatch_ranking'] = [ ['light', index] for index in ranked ]

        return JSON

    def _get_range(self, numbering, chain, scheme, definition, offset=0):
        """
        Get the indices which bound the CDRs for a particular definition.

        Returns a flat list: the index of the first residue of each CDR and, once a
        framework region follows it, the index of its last residue - all shifted by
        offset.
        """
        regions = annotate_regions(numbering, chain, numbering_scheme=scheme, definition=definition )
        boundaries = []
        in_cdr = False
        for position in range( len(numbering) ):
            label = regions[position][2]
            if "cdr" in label and not in_cdr:
                boundaries.append( position+offset )
                in_cdr = True
            elif "fw" in label and in_cdr:
                boundaries.append( position-1+offset )
                in_cdr = False
        return boundaries

    def _write_json(self, data, outputfilename):
        """
        Serialise a dictionary as JSON (keys sorted) into a file.

        @param data: A dictionary
        @param outputfilename: The file name to write the json to.
        """
        with open( outputfilename, 'w' ) as handle:
            json.dump( data, handle, sort_keys=True )

    def _rankdata(self, a):
        """
        Return the indices of a ordered from the highest to the lowest value.
        """
        return sorted( range( len(a) ), key=lambda index: a[index], reverse=True )

    def _restrict_to_fv(self, residues, bfactors, start, end):
        '''
        Post-process the i-Patch prediction so that only the Fv region keeps a score.

        Every residue whose index lies outside [start, end) gets all atom B-factors set
        to 0 and its entry in bfactors set to 0. bfactors is modified in place and
        returned.
        '''
        assert len(residues) == len( bfactors )
        for index, entry in enumerate( residues ):
            if start <= index < end:
                continue
            for atom in entry[0]:
                atom.set_bfactor(0)
            bfactors[index] = 0
        return bfactors




#Parse the options
usage = "USAGE: python RunABIpatch.py --abf ABFILE --abc ABCHAINS --agf AGFILE --agc AGCHAINS --jobid JOBID "
parser = OptionParser(usage=usage)

#Single file mode
parser.add_option("--abf",help="Antibody file location", dest="file_ab")
parser.add_option("--abc",help="Antibody chains to constrain, e.g. --abc HL", dest="chains_ab")
parser.add_option("--agf",help="Antigen file location", dest="file_ag")
parser.add_option("--agc",help="Antigen chains to constrain, e.g. --agc A", dest="chains_ag")
parser.add_option("--paratope",help="Paratope file", dest="para_file")
parser.add_option("--jobid",help="Job id for this process", dest="job_id")

(options, args) = parser.parse_args()
cwd = os.getcwd()
mypath = os.path.dirname(os.path.realpath(__file__))
the_rand = str(int(random.random()*1000000000))

#Validate the command line first: everything below (including the input
#checks and the status file) dereferences these options.
if not (options.file_ab and options.file_ag and options.chains_ab and options.chains_ag and options.job_id):
    print("Not enough input arguments supplied")
    print(usage)
    sys.exit(1)

#NOTE(review): the status file lives under options.job_id while the results
#are written under mypath - presumably both point at the same directory.
status_path = options.job_id+"/output_folder/status.txt"

#Check that all the files are here and that they have the necessary chains
checkAllIsOk(mypath,the_rand,options)
print("Everything appears to be in order<br>")

os.system("mkdir "+mypath+"/temp"+the_rand)
os.system("mkdir "+mypath+"/temp"+the_rand+"/out")
os.system("mkdir "+mypath+"/temp"+the_rand+"/temp_out")
os.system("mkdir "+mypath+"/output_folder")
#Options before the mode: the only order accepted by both GNU and BSD chmod
os.system("chmod -R a+rwx "+mypath+"/temp*")
try:
    with open(status_path,'w') as s_file:
        s_file.write("Predicting")
    #Constrain the antibody file
    ab_map = constrain(options,"AB",the_rand)
    #Constrain the antigen file
    ag_map = constrain(options,"AG",the_rand)
    #Find surface residues for antibody (AB) and antigen (AG) using PSA - doing it from java fails too often
    runPSA("AB",the_rand)
    runPSA("AG",the_rand)
    #Run AB i-Patch (paths are relative to the current working directory)
    ipatch_command = "java -jar ABiPatch.jar temp"+the_rand+"/AB.pdb temp"+the_rand+"/AG.pdb temp"+the_rand+"/tempAB.txt temp"+the_rand+"/tempAG.txt output_folder"
    print(ipatch_command)
    os.system(ipatch_command)
    #Remove all the temp files
    #NOTE(review): the glob also removes temp directories of any other run
    #sharing mypath - confirm that runs never share a directory.
    os.system("rm -r "+mypath+"/temp*")
    os.system("rm mapping* const*")
    #Translate the single chain antibody returned by AB i-Patch to its original coordinates and chains
    constrainBACK(mypath+"/output_folder/antibody.pdb",mypath+"/output_folder/solution.pdb",ab_map["fwd"])
    os.system("cp "+mypath+"/output_folder/solution.pdb "+mypath+"/output_folder/antibody.pdb")

    # Write out a jsonp file that can be used on the front end for sequence visualisation.
    if json_available:
        try:
            IpatchJSONWriter( mypath+"/output_folder/solution.pdb", mypath+"/output_folder/solution.jsonp", options.chains_ab, overwrite=True )
        except Exception as e:
            # Best effort: a failed JSON export must not fail the job, but it
            # is reported on stderr instead of being dropped silently.
            sys.stderr.write("Could not write solution.jsonp: %r\n" % (e,))

    #Remove the single chain antibody file
    os.system("rm "+mypath+"/output_folder/antibody.pdb")
    with open(status_path,'w') as s_file:
        s_file.write("Done")

except BaseException as e:#Catches all errors that might have arisen from executing this file
    with open(status_path,'w') as s_file:
        s_file.write("Crashed")
    if not isinstance(e, Exception):
        # Interrupts and explicit exits are recorded as a crash and then
        # allowed to terminate the process normally.
        raise
    sys.stderr.write("AB i-Patch run failed: %r\n" % (e,))


