import os,sys
from Bio.PDB import PDBParser
from Bio.PDB import PDBIO
import Bio.PDB
from optparse import OptionParser
from os import listdir
from os.path import isfile, join
import tempfile

def create_temp_folder():
    """Create a new, empty temporary directory and return its path.

    The directory is created by ``tempfile.mkdtemp`` in the default temporary
    location (see ``tempfile.gettempdir``), not in the user's home directory.
    The caller owns the directory and is responsible for removing it.
    """
    return tempfile.mkdtemp()

def runPSA(code):
    """Run the bundled PSA executable on ``<mypath>/temp/<code>.pdb``.

    PSA's standard output is redirected to ``<mypath>/temp/temp<code>.txt``.
    On macOS (``sys.platform == 'darwin'``) the ``psa_mac`` build is used,
    otherwise ``psa``. Relies on the module-level globals ``local_path`` (the
    directory containing this script) and ``mypath`` (the working temp
    folder).

    :param code: structure tag; this script passes 'AB' or 'AG'.
    """
    suffix = "_mac" if sys.platform == 'darwin' else ""
    executable = local_path + "/aux/psa" + suffix
    pdb_path = mypath + "/temp/" + code + ".pdb"
    out_path = mypath + "/temp/temp" + code + ".txt"
    # The shell performs the redirection of PSA's output into out_path.
    os.system(executable + " -t " + pdb_path + " > " + out_path)

def constrain(options, code):
    """Merge the selected chains of one input structure into a single,
    renumbered chain and write it to ``<mypath>/temp/<code>.pdb``.

    For ``code == 'AB'`` the antibody inputs (``options.file_ab``,
    ``options.chains_ab``) are used and the merged chain is named 'A'. For any
    other code the antigen inputs (``options.file_ag``, ``options.chains_ag``)
    are used and the merged chain is named 'B'.

    Standard residues (hetero flag ' ') of the selected chains are renumbered
    consecutively from 1 with a blank insertion code. HETATM and TER records
    are dropped from the written file. ``<mypath>/TMP.pdb`` is used as a
    scratch file and removed afterwards. Relies on the module-level global
    ``mypath``.

    :param options: parsed command-line options (optparse values).
    :param code: 'AB' for the antibody, anything else for the antigen.
    :return: ``{'fwd': {new_number: key}, 'bwd': {key: new_number}}`` where
        ``key`` is ``"<chain>_<resseq><icode>"`` of the original residue.
    """
    mapping = {"fwd": {}, "bwd": {}}
    if code == 'AB':
        pdb_file = options.file_ab
        chains = options.chains_ab
        constr_chain = 'A'
    else:
        pdb_file = options.file_ag
        chains = options.chains_ag
        constr_chain = 'B'
    # chains is given as a string such as "HL": one character per chain id.
    chains_to_choose = set(chains)

    structure = PDBParser().get_structure("input", pdb_file)

    # NOTE(review): every selected chain is renamed to constr_chain and its
    # residues renumbered in place. Newer Biopython releases reject an id that
    # is already used by a sibling entity, so multi-chain selections may fail
    # there -- confirm against the Biopython version in use.
    new_number = 1
    for model in structure:
        for chain in model:
            if chain.id not in chains_to_choose:
                # An unselected chain that already carries the target id would
                # otherwise be accepted by ConstrSelect, so move it aside.
                if chain.id == constr_chain:
                    chain.id = '0'
                continue
            for residue in chain:
                hetflag, resseq, icode = residue.id
                if hetflag != ' ':
                    continue
                key = str(chain.id) + "_" + str(resseq) + str(icode).replace(" ", "")
                mapping['fwd'][new_number] = key
                mapping['bwd'][key] = new_number
                residue.id = (' ', new_number, ' ')
                new_number += 1
            chain.id = constr_chain

    class ConstrSelect(Bio.PDB.Select):
        """Accept only the chain(s) now named constr_chain."""

        def accept_chain(self, chain):
            return 1 if chain.id == constr_chain else 0

    tmp_path = mypath + "/TMP.pdb"
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(tmp_path, ConstrSelect())

    # Copy the selection to its final location without HETATM and TER lines.
    with open(tmp_path, 'r') as f_tmp, open(mypath + '/temp/' + code + '.pdb', 'w') as f_out:
        for line in f_tmp:
            if line[0:3] != "TER" and line[0:6] != "HETATM":
                f_out.write(line)
    os.remove(tmp_path)
    return mapping

def translatePara(para_file, ab_map, out_file=None):
    """Write the renumbered ids of antibody residues NOT listed in *para_file*.

    *para_file* holds one residue per line as ``<residue number> <chain>``
    (e.g. ``"52A H"``); fields may be separated by any whitespace and blank
    lines are ignored. Each line becomes the ``<chain>_<residue number>`` key
    used by ``ab_map['bwd']``. Every mapped residue whose key is absent is
    written, one renumbered id per line.

    NOTE(review): *para_file* is produced by Framer.py; whether it lists
    paratope or framework residues cannot be told from this file.

    :param para_file: path of the residue list to translate.
    :param ab_map: residue mapping returned by ``constrain(options, "AB")``.
    :param out_file: output path; defaults to ``<mypath>/temp/paratope.txt``.
    """
    if out_file is None:
        out_file = mypath + '/temp/paratope.txt'
    listed = set()
    with open(para_file, 'r') as fr:
        for line in fr:
            fields = line.split()
            if not fields:
                continue
            listed.add(fields[1] + "_" + fields[0])
    with open(out_file, 'w') as fw:
        for key in ab_map['bwd']:
            if key not in listed:
                fw.write(str(ab_map['bwd'][key]) + "\n")

def translateParaZDOCK(para_file, ab_map, out_file=None):
    """Write the ZDOCK residue file for the antibody.

    *para_file* holds one residue per line as ``<residue number> <chain>``;
    fields may be separated by any whitespace and blank lines are ignored.
    Every residue of ``ab_map['bwd']`` whose ``<chain>_<residue number>`` key
    is NOT listed is written as ``"<renumbered id> A"``. 'A' is the chain id
    ``constrain`` assigns to the merged antibody chain.

    :param para_file: path of the residue list to translate.
    :param ab_map: residue mapping returned by ``constrain(options, "AB")``.
    :param out_file: output path; defaults to
        ``<mypath>/temp/paratope_zdock.txt``.
    """
    if out_file is None:
        out_file = mypath + '/temp/paratope_zdock.txt'
    listed = set()
    with open(para_file, 'r') as fr:
        for line in fr:
            fields = line.split()
            if not fields:
                continue
            listed.add(fields[1] + "_" + fields[0])
    with open(out_file, 'w') as fw:
        for key in ab_map['bwd']:
            if key not in listed:
                fw.write(str(ab_map['bwd'][key]) + " A\n")

def translateParaRESCORING(para_file, ab_map, out_file=None):
    """Write the rescoring residue file for the antibody.

    *para_file* holds one residue per line as ``<residue number> <chain>``;
    fields may be separated by any whitespace and blank lines are ignored.
    Every residue of ``ab_map['bwd']`` whose ``<chain>_<residue number>`` key
    IS listed is written as ``"<renumbered id> A"``. 'A' is the chain id
    ``constrain`` assigns to the merged antibody chain.

    :param para_file: path of the residue list to translate.
    :param ab_map: residue mapping returned by ``constrain(options, "AB")``.
    :param out_file: output path; defaults to
        ``<mypath>/temp/paratope_rescoring.txt``.
    """
    if out_file is None:
        out_file = mypath + '/temp/paratope_rescoring.txt'
    listed = set()
    with open(para_file, 'r') as fr:
        for line in fr:
            fields = line.split()
            if not fields:
                continue
            listed.add(fields[1] + "_" + fields[0])
    with open(out_file, 'w') as fw:
        for key in ab_map['bwd']:
            if key in listed:
                fw.write(str(ab_map['bwd'][key]) + " A\n")

def translateResultsOriginal(ag_map, in_dir=None, out_dir=None):
    """Translate the predictor output back to original antigen residue ids.

    Every regular file in *in_dir* whose name contains ``.txt`` is read as a
    list of renumbered residue ids: one integer per line, blank lines ignored.
    For each such file, ``original_<name>`` is written to *out_dir*. It holds
    every residue of ``ag_map['fwd']`` whose renumbered id is NOT listed,
    formatted as ``"<residue number><icode> <chain>"``.

    :param ag_map: residue mapping returned by ``constrain(options, "AG")``.
    :param in_dir: folder with the predictor output; defaults to
        ``<mypath>/temp/temp_out``.
    :param out_dir: destination folder; defaults to ``<mypath>/temp/out``.
    """
    if in_dir is None:
        in_dir = mypath + "/temp/temp_out"
    if out_dir is None:
        out_dir = mypath + "/temp/out"
    for name in listdir(in_dir):
        src = join(in_dir, name)
        if '.txt' not in name or not isfile(src):
            continue
        with open(src, 'r') as fr:
            listed = set(int(line) for line in fr if line.strip())
        with open(join(out_dir, 'original_' + name), 'w') as fw:
            for number in ag_map['fwd']:
                if number not in listed:
                    # fwd values are "<chain>_<resseq><icode>"; emit them in the
                    # "<number> <chain>" layout that translatePara reads.
                    parts = str(ag_map['fwd'][number]).split("_")
                    fw.write(parts[1] + " " + parts[0] + "\n")
	
def translateResults(ag_map, in_dir=None, out_dir=None):
    """Write the predictor output as residue lists for docking.

    Every regular file in *in_dir* whose name contains ``.txt`` is read as a
    list of renumbered residue ids: one integer per line, blank lines ignored.
    A file with the same name is written to *out_dir*. It holds every
    renumbered id of ``ag_map['fwd']`` that is NOT listed, one per line as
    ``"<id> B"``. 'B' is the chain id ``constrain`` gives the merged antigen
    chain.

    :param ag_map: residue mapping returned by ``constrain(options, "AG")``.
    :param in_dir: folder with the predictor output; defaults to
        ``<mypath>/temp/temp_out``.
    :param out_dir: destination folder; defaults to ``<mypath>/temp/out``.
    """
    if in_dir is None:
        in_dir = mypath + "/temp/temp_out"
    if out_dir is None:
        out_dir = mypath + "/temp/out"
    for name in listdir(in_dir):
        src = join(in_dir, name)
        if '.txt' not in name or not isfile(src):
            continue
        with open(src, 'r') as fr:
            listed = set(int(line) for line in fr if line.strip())
        with open(join(out_dir, name), 'w') as fw:
            for number in ag_map['fwd']:
                if number not in listed:
                    fw.write(str(number) + " B\n")

def check_file(filename):
    """Return a readable path for *filename*, or exit the program.

    The name is first resolved relative to the current working directory. If
    no readable file is found there, *filename* is tried as given (e.g. an
    absolute path). If neither can be opened, an error message is printed and
    the process exits with status 1 via ``sys.exit``. ``quit()`` is avoided
    because it is only defined by the ``site`` module and exits with status 0.

    :param filename: relative or absolute path supplied on the command line.
    :return: the first candidate path that could be opened for reading.
    """
    for candidate in (os.path.join(os.getcwd(), filename), filename):
        try:
            with open(candidate, 'r'):
                return candidate
        except IOError:
            pass  # not readable here; try the next candidate
    print("File " + filename + " does not exist.")
    sys.exit(1)

#MAIN#
import shutil
import subprocess

#Graph cutoff - the allowed difference between intra-atomic pairs
graph_cut = str(1.0)
#Percentage cutoff overlap at which two epitopes are considered as the same (0.5=50%)
perc_cut = str(0.3)
#Parse the options
usage = "USAGE: python EpiPred.py --abf ABFILE --abc ABCHAINS --agf AGFILE --agc AGCHAINS --epitopes EPIS --jobid OUTNAME "
parser = OptionParser(usage=usage)

#Single file mode
parser.add_option("--abf",help="Antibody file location", dest="file_ab")
parser.add_option("--abc",help="Antibody chains to constrain, e.g. --abc HL", dest="chains_ab")
parser.add_option("--agf",help="Antigen file location", dest="file_ag")
parser.add_option("--agc",help="Antigen chains to constrain, e.g. --agc ABCD", dest="chains_ag")
#Server version requires the full path for the output folder
parser.add_option("--jobid",help="Job id for this process", dest="job_id")
parser.add_option("--epitopes",help="Number of epitopes to produce", dest="num_epis")
parser.add_option("--para_ext",help="Paratope extension", dest="para_ext",default="0.0")
#Parameters that might be changed for different antigens.
#Parameters pertaining to re-scoring
parser.add_option("--original_decoys",help="Number of top original decoys to rescore", dest="original",default="200")
parser.add_option("--rescored_decoys",help="Number of decoys to return", dest="to_rescore",default="30")
#Parameters used for surface sampling
parser.add_option("--neighbor_depth",help="How many times repeat the neighborhood extension", dest="depth",default="2")
parser.add_option("--neighbor_distance",help="Distance at which neighboring atoms are considered", dest="neigh_distance",default="4.5")

(options, args) = parser.parse_args()


def _write_status(job_id, text):
    """Overwrite ``<job_id>/status.txt`` with a one-line progress message."""
    with open(os.path.join(job_id, "status.txt"), 'w') as s_file:
        s_file.write(text)


def _copy_dir_contents(src, dst):
    """Copy every non-hidden entry of *src* into the existing folder *dst*
    (the equivalent of ``cp -r src/* dst``)."""
    for name in os.listdir(src):
        if name.startswith('.'):
            continue  # the shell glob '*' skips hidden entries too
        src_path = os.path.join(src, name)
        dst_path = os.path.join(dst, name)
        if os.path.isdir(src_path):
            shutil.copytree(src_path, dst_path)
        else:
            shutil.copy2(src_path, dst_path)


required_args = (options.file_ab, options.file_ag, options.chains_ab,
                 options.chains_ag, options.num_epis, options.job_id)
if not all(required_args):
    # Without these there is nothing to predict or dock, and the status file
    # location may be unknown, so stop instead of running ZDOCK on missing input.
    print("Not enough input arguments supplied")
    print(usage)
    sys.exit(1)

original_decoys = int(options.original)
to_rescore = int(options.to_rescore)

#Keep track of all the temp folders so that they can be removed at the end
mypath = create_temp_folder()
paratope_folder = create_temp_folder()
out_temp = create_temp_folder()
temp_folders = [mypath, paratope_folder, out_temp]

os.makedirs(os.path.join(mypath, "temp", "out"))
os.makedirs(os.path.join(mypath, "temp", "temp_out"))
local_path = os.path.dirname(os.path.realpath(__file__))

print("Predicting epitope with nd=",options.neigh_distance,", d=",options.depth)

try:
    # Create the job folder first so that a failure below can still be recorded.
    if not os.path.isdir(options.job_id):
        os.makedirs(options.job_id)

    #Check the validity of the submitted files
    options.file_ab = check_file(options.file_ab)
    options.file_ag = check_file(options.file_ag)

    #Write the Status - epitope prediction [1/2]
    _write_status(options.job_id, "Predicting the epitope, step 1 of 2")
    #Create the paratope file. Arguments are passed as a list (no shell), so
    #user-supplied chains and paths cannot inject shell commands.
    subprocess.call(["python", local_path + "/Framer.py",
                     "--f", options.file_ab, "--c", options.chains_ab,
                     "--o", paratope_folder, "--d", "chothia"])
    para_file = paratope_folder + "/paratope.txt"
    ab_map = constrain(options, "AB")
    runPSA("AB")
    ag_map = constrain(options, "AG")
    runPSA("AG")
    #translate the para file
    translatePara(para_file, ab_map)
    translateParaZDOCK(para_file, ab_map)
    translateParaRESCORING(para_file, ab_map)
    #Run the epitope predictor in Java. Arguments: results folder, graph
    #algorithm cutoff, percentage overlap, paratope file, number of epitope
    #predictions (3), paratope extension, install folder, neighbourhood
    #depth, neighbourhood distance.
    # NOTE(review): --epitopes (options.num_epis) is required but never
    # forwarded; the Java call always uses 3.
    subprocess.call(["java", "-jar", local_path + "/EpiPred.jar",
                     mypath + "/temp", graph_cut, perc_cut,
                     mypath + "/temp/paratope.txt", "3", options.para_ext,
                     local_path, str(options.depth),
                     str(options.neigh_distance)])

    #The epitopes are given in standardized files format, map them back to their original residue ids etc.
    translateResults(ag_map)
    translateResultsOriginal(ag_map)
    _copy_dir_contents(mypath + "/temp", out_temp)
    shutil.rmtree(mypath, ignore_errors=True)

    zdock_temp = create_temp_folder()
    temp_folders.append(zdock_temp)

    #Write the Status - docking [2/2]
    _write_status(options.job_id, "Docking, step 2 of 2")
    print("ZDOCK files will be written to ",zdock_temp)

    # NOTE(review): the script is now executed directly rather than through
    # /bin/sh, so zdock.sh needs a shebang line (it already needed +x).
    subprocess.call([local_path + "/ZDOCK/zdock.sh",
                     out_temp + "/AB.pdb", out_temp + "/AG.pdb", zdock_temp,
                     out_temp, out_temp + "/paratope_zdock.txt",
                     out_temp + "/paratope_rescoring.txt", options.job_id,
                     str(to_rescore), str(original_decoys)])

    print("Removing all the temp files, almost done...")
    # NOTE(review): deleting the temp folders is disabled (the rm -rf was
    # commented out); their paths are only printed.
    for temp_f in temp_folders:
        print(temp_f)

    _write_status(options.job_id, "Done")
except BaseException:
    # Record the failure for whoever polls status.txt, then let the error
    # (traceback and exit code) propagate instead of silently swallowing it.
    if os.path.isdir(options.job_id):
        _write_status(options.job_id, "Crashed")
    raise
