#!/usr/bin/python
# -*- coding: utf-8 -*-
#

import sys
import os.path
from array import array

import sqlite3 as sql

import numpy

# Make sure we can import stuff from this file's directory, and from any
# directory listed in the PATH environment variable.
sys.path.append(os.path.abspath(os.path.dirname(sys.argv[0])))
# Split on the platform's separator (os.pathsep) rather than a hard-coded ':',
# and tolerate PATH being unset instead of dying with a KeyError.
_path_env = os.environ.get('PATH')
if _path_env:
  sys.path.extend(_path_env.split(os.pathsep))

from prosci.util.residue import ResidueList
from prosci.util.ali import Ali
from prosci.shell import Params

from prosci.loops.loopmodel import ANCHOR_LENGTH, add_oxygens, get_dihedral_class, is_structure_ok, describe_anchors


def write_structure(f, struc):
  """Write every element of `struc` to the open file-like object `f`.

  Each element is preceded by a line containing only ">" and is written
  using its str() representation, with nothing added after it.
  """
  record_header = ">\n"
  for residue in struc:
    f.write(record_header + str(residue))



# Parse the command line: "force" is a plain flag, the options listed in
# `withargument` each take a value.
params = Params(allowed=["force"], withargument=["terminal", "maxb", "minlen", "maxlen"])

# At least the database directory and one structure file are required.
if len(params.args) < 2:
  print("USAGE: %s <pyfread_db> <pdb_file ...>" % (params.scriptname))
  print("""
    <pyfread_db>    Root directory of the PyFREAD database
    <pdb_file ...>  One or more PDB files to be added to the database
  
  OPTIONS:
      --minlen NUM      Minimum loop length (default: 3)
      --maxlen NUM      Maximum loop length (default: 30)
      --maxb NUM        Set maximum B factor per residue to NUM.
      --terminal NUM    Ignore the first and last NUM residues (default: 5).
      --force           Insert loops into database, even if TEM file exists.
  """)
  sys.exit(1)


min_loop_length = int(params.getOpt("minlen", 3))
max_loop_length = int(params.getOpt("maxlen", 30))
terminal_cutoff = int(params.getOpt("terminal", 5)) # Skip this many residues at the N and C termini
# Validate explicitly rather than with `assert`, which is stripped when
# Python runs with -O and would let an invalid value through.
if terminal_cutoff < 1:
  sys.exit("ERROR: --terminal must be at least 1 (got %d)" % terminal_cutoff)
# NOTE(review): max_b is parsed here but is not referenced anywhere else in
# this script, so --maxb currently has no effect -- confirm whether B-factor
# filtering should be implemented.
max_b = float(params.getOpt("maxb", 10000)) # Max average b factor per loop
force = params.isOpt("force")


# First positional argument is the database root; structure files live in
# its "structures" sub-directory.
dbdir = params.args[0]
strucdir = os.path.join(dbdir, "structures")

# All remaining positional arguments are input structures, handled in
# sorted order.
structure_files = sorted(params.args[1:])

# Create the database root and the structures directory if they are missing.
for required_dir in (dbdir, strucdir):
  if not os.path.isdir(required_dir):
    os.mkdir(required_dir)



# One SQLite database file per loop length ("length<N>.sqlite" inside dbdir).
# Connections and cursors are kept in parallel lists; the entry for loop
# length N sits at index N - min_loop_length.
database_connections = []
database_cursors = []
create_loops_table = '''CREATE TABLE IF NOT EXISTS loops (dihedral text, sequence text, pdbcode text, start int, casep1 real, casep2 real, casep3 real, casep4 real)'''
for length in xrange(min_loop_length, max_loop_length+1):
  connection = sql.connect(os.path.join(dbdir, "length%d.sqlite" % length))
  db_cursor = connection.cursor()
  database_cursors.append(db_cursor)
  database_connections.append(connection)

  # Make sure the table exists before any loops are inserted
  db_cursor.execute(create_loops_table)



# Main loop: for every input structure, store a copy of it in the database
# directory, compute (or reload) its per-residue dihedral classes and insert
# every acceptable fragment into the database for its loop length.
# The `finally` clause commits and closes all databases even if processing
# fails part-way through.
try:

  for ifile, fname in enumerate(structure_files):

    # Structure name: file name without directory and extension, minus an
    # optional leading "pdb", with the first four characters lower-cased.
    strucname = os.path.splitext(os.path.basename(fname))[0]
    if strucname.lower().startswith("pdb"):
      strucname = strucname[3:]
    assert strucname
    strucname = strucname[:4].lower() + strucname[4:]

    # Per-structure files are grouped in a sub-directory named after the
    # second and third characters of the structure name.
    destination = os.path.join(strucdir, strucname[1:3])
    struc_file_destination = os.path.join(destination, strucname+".atm")
    annotation_file_destination = os.path.join(destination, strucname+".tem")

    # Skip stuff we've already inserted into the database
    if not force and os.path.exists(annotation_file_destination):
      continue

    if os.path.exists(struc_file_destination):
      # Reuse the copy that was stored on an earlier run
      residues = ResidueList(struc_file_destination)
    else:
      residues = ResidueList(fname)
      # Ignore empty or very short structures
      if len(residues) < 5 or not residues:
        continue
      if not os.path.isdir(destination):
        os.mkdir(destination)
      add_oxygens(residues)
      f = open(struc_file_destination, "w")
      try:
        write_structure(f, residues)
      finally:
        f.close()

    print("%d/%d : %s" % (ifile+1, len(structure_files), strucname))

    sequence = residues.get_seq()
    assert len(sequence) == len(residues)

    if os.path.exists(annotation_file_destination):
      # Only reachable with --force: reuse the stored dihedral classes
      annotation = Ali(annotation_file_destination)
      dihedrals = annotation[0]["FREAD dihedral class"].seq
    else:
      # Dihedral angles. Also serves as a mask - residues marked '?' are excluded from the database.
      # The first and last residue are always marked '?'.
      dihedrals = array('c', "?")
      for i in xrange(1, len(residues)-1):
        if not is_structure_ok(residues, i-1, i+2) or sequence[i] == 'X':
          x = '?'
        else:
          try:
            x = str(get_dihedral_class(residues, i))
          except ValueError:
            # Math domain error, if computing angles between overlayed atoms
            x = '?'
        dihedrals.append(x)
      dihedrals.append("?")
      dihedrals = dihedrals.tostring()
      assert len(dihedrals) == len(residues)

    for loop_length in xrange(min_loop_length, max_loop_length+1):
      # Cursors are stored in order of increasing loop length
      cursor = database_cursors[loop_length-min_loop_length]
      # A fragment is the loop plus one anchor on either side
      totallen = loop_length + 2*ANCHOR_LENGTH

      # Loop over all possible fragments within the protein
      # NOTE(review): the upper bound excludes
      # start == len(residues)-terminal_cutoff-totallen, so one more residue
      # is skipped at the C terminus than at the N terminus -- confirm
      # whether this is intended before changing it.
      for start in xrange(terminal_cutoff, len(residues)-terminal_cutoff-totallen):
        end = start + totallen
        loopstart = start+ANCHOR_LENGTH
        loopend = end-ANCHOR_LENGTH

        # Skip bad loops (any masked residue in the loop or its anchors)
        if '?' in dihedrals[start:end]:
          continue

        loopdihedrals = dihedrals[loopstart:loopend]
        loopseq = sequence[loopstart:loopend]

        # Describe anchor
        anchor_description, transform = describe_anchors(residues[start:loopstart], residues[loopend:end], loop_length)

        # Four scalar values derived from the anchor description; stored in
        # the casep1..casep4 columns.
        casep1 = anchor_description.C1[0]
        casep2 = numpy.linalg.norm(anchor_description.C2)
        casep3 = numpy.linalg.norm(anchor_description.N1 - anchor_description.C1)
        casep4 = numpy.linalg.norm(anchor_description.N1 - anchor_description.C2)

        cursor.execute("INSERT INTO loops VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (loopdihedrals, loopseq, strucname, start, casep1, casep2, casep3, casep4))

    # The annotation file is written last, so its presence marks a structure
    # whose loops have all been inserted.
    if not os.path.exists(annotation_file_destination):
      f = open(annotation_file_destination, "w")
      try:
        f.write(">%s\nsequence\n%s\n>%s\nFREAD dihedral class\n%s\n"%(strucname, sequence, strucname, dihedrals))
      finally:
        f.close()

finally:
  # Commit and release every database. The unpacking order matches the zip
  # arguments (connection first, cursor second), so commit() and the final
  # close() are called on the connection itself.
  for conn, cursor in zip(database_connections, database_cursors):
    try:
      conn.commit()
    finally:
      cursor.close()
      conn.close()
