#!/usr/bin/env python

"""
correct_imgt_annotations.py
Description: 	Correct IMGT annotations in my local version of SAbDab.
May 23, 2017
"""
from glob import glob
from anarci import anarci
from Bio import SeqIO
import os, shutil
from ABDB import database as db

# PDB identifiers gathered under the name "banned_list".
# NOTE(review): nothing in the visible part of this script reads this list;
# it is kept as-is in case it is used elsewhere or for manual filtering.
banned_list = [
    "4hjg", "4hjj", "4org", "4od1", "4ocw", "4ocs", "4od3", "4odh",
    "4ocr", "3upc", "4unt", "3eot", "1sjv", "1h8o", "1h8s", "1fn4",
    "3mlt", "1ngx", "3mls", "3mlr", "3mlu", "3mlv", "3dvg", "3dvn",
    "1oau", "2gki", "1lmk", "1nqb", "1moe", "4k3e", "4k3d", "4kro",
    "5dt1", "4dqo", "4hqq", "4isv", "4jm4", "4lrt", "4lrn", "4lsp",
    "4lsq", "4lsr", "4nuj", "4ob5", "4osu", "4pub", "4x8j", "4hx2",
    "4yaq", "3c08", "3d0l", "3drq", "3lrs", "3mug", "3na9", "3qeg",
    "3qhz", "3u1s", "3u2s", "3u4e", "1fgv", "1mie", "1mj7", "1rzf",
    "1rzg", "2a77", "2dtm", "2hff", "2iqa", "5e7b", "5ihu", "5e99",
    "5ijv", "5ilt", "3bj9", "5hcg", "5fhx", "5fcs", "5gry", "5gru",
]

# Every sub-directory of the current working directory (each name keeps its
# trailing "/"); the processing loop treats each one as a candidate entry.
entries = glob("*/")

# PDB identifiers gathered under the name "scfv".
# NOTE(review): the only reference to this list in the visible script is a
# commented-out filter at the top of the processing loop.
scfv = [
    "4buh", "3j8d", "2gki", "5j74", "5j75", "3umt",
    "3uzq", "5grz", "4nik", "5gs2", "4f9l", "5grw",
    "4f9p", "5jyl", "3uze", "1x9q", "5fxc", "1jp5",
    "1h8o", "5f72", "5u0u", "3auv", "3etb", "4h0g",
    "4h0h", "4h0i", "5d9q", "4yjz", "1f3r", "3fku",
    "4gqp", "5lx9", "5aam", "1qok", "3uzv", "1svz",
    "5a2i", "5jym", "5i4f", "2ghw", "3uyp", "2gjj",
    "3h3b", "5lxa", "3ux9", "5kov", "5dhx", "5aaw",
    "3gkz", "5gru", "5iwl", "3wbd", "5u68", "5gs1",
    "1moe", "1lmk", "5dfw", "4cau", "3esu", "5b3n",
    "5f3j", "3esv", "2kh2", "1h8s", "5fcs", "5a2l",
    "3gm0", "5grx", "5gry", "1nqb", "5c6w", "5a2j",
    "5a2k", "3juy", "1dzb", "3et9", "5gs3", "1h8n",
]

#idx = entries.index("5hcg/")

def _number_variable_domain(sequence, chtype):
    """Number `sequence` with ANARCI and return the residue list of one domain.

    The ANARCI result is indexed the same way the rest of this script always
    did: ``result[0][0]`` is the list of domains found in the sequence,
    ``domain[0]`` is that domain's list of ``((number, insertion), residue)``
    entries, and ``result[1][0][i]['chain_type']`` is the chain type of
    domain ``i``.

    sequence -- amino-acid sequence string to number.
    chtype   -- "H" to select the heavy domain; any other value selects the
                light domain when more than one domain is found.

    Returns the numbering list of the selected domain, or None when ANARCI
    reports no domain. Exceptions raised by ANARCI propagate to the caller.
    """
    result = anarci([('seq', sequence)])
    domains = result[0][0]
    if not domains:
        return None
    if len(domains) == 1:
        return domains[0][0]

    # More than one domain in a single chain (e.g. an scFv): decide which of
    # the first two domains is heavy from the chain type of the first one.
    if "H" in result[1][0][0]['chain_type']:
        hidx, lidx = 0, 1
    else:
        hidx, lidx = 1, 0
    return domains[hidx][0] if chtype == "H" else domains[lidx][0]


# For every entry directory known to the database, re-number each chain with
# ANARCI and rewrite the matching IMGT annotation (.ann) file.
for pdb in entries:
    #if pdb.strip("/") not in db or pdb.strip("/") not in scfv:
    if pdb.strip("/") not in db:
        continue

    path = os.path.join(pdb, 'annotation', 'imgt')
    seqpath = os.path.join(pdb, "sequences", "imgt")

    print(pdb)
    annfiles = {}

    # Map chain id -> annotation file, backing each file up first.
    for annfile in glob(os.path.join(path, '*.ann')):
        # Split the file name only, so an underscore in a directory name
        # cannot shift the fields.
        ann_parts = os.path.basename(annfile).split("_")
        if len(ann_parts) < 2:
            continue
        # Keep the first backup: re-running the script must not replace the
        # original annotation with an already-corrected one.
        backup = annfile + '.bak'
        if not os.path.exists(backup):
            shutil.copy(annfile, backup)
        annfiles[ann_parts[1]] = annfile

    for seqfile in glob(seqpath+"/*"):
        seq_parts = os.path.basename(seqfile).split("_")
        if len(seq_parts) < 3:
            print("Skipping %s: unexpected file name" % seqfile)
            continue
        chainid = seq_parts[1]
        chtype = seq_parts[2].replace(".fa", "").replace("V", "")

        print(chtype)

        # Read all records once (in file order) and close the file promptly.
        with open(seqfile) as f:
            records = [(rec.description, str(rec.seq))
                       for rec in SeqIO.parse(f, 'fasta')]

        # Preferred source: the full SEQRES sequence. If several records
        # match, the last one that numbers successfully wins.
        numbering = None
        for description, sequence in records:
            if 'seqres|full' not in description:
                continue
            try:
                candidate = _number_variable_domain(sequence, chtype)
            except Exception:
                continue
            if candidate:
                numbering = candidate

        # Fallback: the first full structure sequence in the file.
        if not numbering:
            for description, sequence in records:
                if "structure|full" in description:
                    try:
                        numbering = _number_variable_domain(sequence, chtype)
                    except Exception:
                        numbering = None
                    break

        if not numbering:
            print("Skipping %s: no numbering could be obtained" % seqfile)
            continue

        targets = [sequence for description, sequence in records
                   if "seqres|region" in description]
        if not targets:
            print("Skipping %s: no seqres|region record" % seqfile)
            continue
        target_seq = targets[0].replace("-", "")

        # Drop gap positions BEFORE locating the target: the offset found in
        # the gap-free string is only valid as an index into a gap-free list.
        residues = [n for n in numbering if n[1] != "-"]
        numbered_seq = "".join([n[1] for n in residues])
        offset = numbered_seq.find(target_seq)
        if offset < 0:
            print("Skipping %s: region not found in numbered sequence" % seqfile)
            continue
        residues = residues[offset:]

        if chainid not in annfiles:
            print("Skipping %s: no annotation file for chain %s" % (seqfile, chainid))
            continue

        # Make the annotation file name agree with the chain type. Only the
        # file name is edited, never the directory part of the path.
        # NOTE(review): when the name changes, the old file is left on disk.
        ann_dir, ann_name = os.path.split(annfiles[chainid])
        if chtype == "L" and "VL" not in ann_name:
            ann_name = ann_name.replace("VH", "VL")
        elif chtype == "H" and "VH" not in ann_name:
            ann_name = ann_name.replace("VL", "VH")
        annfiles[chainid] = os.path.join(ann_dir, ann_name)

        # One line per residue: <chain type><number><insertion>\t<residue>
        with open(annfiles[chainid], 'w') as outf:
            for n in residues:
                outf.write("%s%d%s\t%s\n" % (chtype, n[0][0], n[0][1], n[1]))
