#!/usr/bin/env python
# Usage text; the __main__ guard writes it to stderr when the command line is unusable.
docstring='''
Get the base pair information based on the input 3D structure in PDB format.

Usage: 3Dto2D.py pdb.pdb

The output contains the following four parts:
    Title: File name and chain ID
    SEQ:   Nucleotide sequence 
    SS:    Secondary structure
           '<' for paired with downstream base
           '>' for paired with upstream base
           '.' for unpaired
    DIS:   Chain discontinuity. 0 for continuous. 1 for discontinuous.
    CT:    List of base pairs. Nucleotide index starts from 0. 
'''

import sys,os,math
import numpy as np


# Residue names accepted as standard RNA; chain_into_resds stops parsing at
# the first ATOM record whose residue name is not in this list.
code_rna=list('AUGC')

# Maps '0'-prefixed residue names onto the plain nucleotide letter.
# NOTE(review): not referenced anywhere else in this file - presumably kept
# for modified residues; confirm before removing.
code_rna_modf=dict(('0'+base,base) for base in code_rna)

def angle(x1,y1,z1,x2,y2,z2,x3,y3,z3):
    """Return the angle (radians) between the vectors p1->p2 and p2->p3.

    Each coordinate may be a number or a numeric string; it is passed
    through float().  The result is 0 when the three points lie on a straight
    line in the order p1, p2, p3 and pi when p3 folds back towards p1.

    If p1 == p2 or p2 == p3 one vector has zero length and the result is nan
    (numpy also emits a RuntimeWarning), as before.
    """
    v12=np.array([float(x2)-float(x1),float(y2)-float(y1),float(z2)-float(z1)])
    v23=np.array([float(x3)-float(x2),float(y3)-float(y2),float(z3)-float(z2)])
    norm12=np.sqrt(v12.dot(v12))
    norm23=np.sqrt(v23.dot(v23))
    cos_angle=v12.dot(v23)/(norm12*norm23)
    # Rounding can leave the cosine of (anti)parallel vectors a hair outside
    # [-1, 1]; arccos would then return nan and every "angle < limit" test
    # in the callers would silently fail.  Clamp before taking the arccos.
    return np.arccos(np.clip(cos_angle,-1.0,1.0))

def distant(x1,y1,z1,x2,y2,z2):
    """Return the Euclidean distance between (x1,y1,z1) and (x2,y2,z2).

    Coordinates may be numbers or numeric strings; each is converted with
    float() before use.
    """
    dx=float(x1)-float(x2)
    dy=float(y1)-float(y2)
    dz=float(z1)-float(z2)
    return math.sqrt(dx*dx+dy*dy+dz*dz)


def basepair_type(a,b):
    """Return 1 when bases a and b can pair (A-U, G-C or G-U, either order), else 0."""
    allowed=(
        ('U','A'),('A','U'),
        ('C','G'),('G','C'),
        ('U','G'),('G','U'),
    )
    return 1 if (a,b) in allowed else 0

def _c3prime_coords(residue_txt):
    """Return the (x, y, z) strings of the last C3' record in residue_txt, or None."""
    found=None
    for line in residue_txt.splitlines():
        if line[12:16].strip()=="C3'":
            found=(line[30:38].strip(),line[38:46].strip(),line[46:54].strip())
    return found


def R_basepair(xxx,yyy,resd1,resd2):
    """Relaxed base-pair test based only on the C3'-C3' distance.

    xxx, yyy     -- PDB ATOM lines (newline separated) of the two residues
    resd1, resd2 -- residue names of xxx and yyy

    Returns 1 when the two bases are of a pairing type (see basepair_type)
    and their C3' atoms are strictly between 12.5 and 14.5 Angstrom apart,
    otherwise 0.
    """
    if basepair_type(resd1,resd2)==0: return 0
    disc3min=12.5
    disc3max=14.5 #15.5
    first=_c3prime_coords(xxx)
    second=_c3prime_coords(yyy)
    # A residue without C3' used to be placed at the origin, so the result
    # depended on where the other residue happened to sit in space.
    if first is None or second is None:
        return 0
    separation=distant(*(first+second))
    if disc3min<separation<disc3max:
        return 1
    return 0
 
# Geometry checked for every supported pair, keyed by (first base, second base).
# Each value is (hydrogen bonds, angle atoms):
#   hydrogen bonds -- (atom in first base, atom in second base) pairs whose
#                     distance must lie between HBmin and HBmax
#   angle atoms    -- three (residue, atom) entries handed to angle();
#                     residue 0 is the first base, 1 the second base
_BASEPAIR_RULES={
    ('G','C'):((('N1','N3'),('N2','O2'),('O6','N4')),
               ((0,'C4'),(0,'N1'),(1,'N3'))),
    ('A','U'):((('N1','N3'),('N6','O4')),
               ((0,'C4'),(0,'N1'),(1,'N3'))),
    ('G','U'):((('N1','O2'),('O6','N3')),
               ((1,'C6'),(1,'N3'),(0,'O6'))),
}


def _atom_coords(residue_txt):
    """Map atom name -> (x, y, z) coordinate strings for the ATOM lines of one residue.

    The name is taken from the full four-column atom field (line[12:16]) so
    that sugar atoms such as C4' or O2' are never mistaken for the base atoms
    C4 or O2.  When a name occurs more than once the last record wins.
    """
    coords={}
    for line in residue_txt.splitlines():
        coords[line[12:16].strip()]=(line[30:38].strip(),line[38:46].strip(),line[46:54].strip())
    return coords


def basepair(xxx,yyy,resd1,resd2,HBmax):
    """Test whether two residues form a G-C, A-U or G-U pair by hydrogen-bond geometry.

    xxx, yyy     -- PDB ATOM lines (newline separated) of the two residues
    resd1, resd2 -- residue names of xxx and yyy
    HBmax        -- upper limit for a hydrogen-bond distance; the lower
                    limit is fixed at 1.8

    Returns
       1  every hydrogen-bond distance is strictly inside (1.8, HBmax) and
          the angle listed in _BASEPAIR_RULES is below pi/3+0.05
       0  the bases are not one of the supported pairs, or the geometry fails
      -1  an atom needed for the test is absent from one of the residues
    """
    HBmin=1.8
    # Orient the pair so that residues[0] is the first base of the rule key.
    if (resd1,resd2) in _BASEPAIR_RULES:
        hbonds,angle_atoms=_BASEPAIR_RULES[(resd1,resd2)]
        residues=(_atom_coords(xxx),_atom_coords(yyy))
    elif (resd2,resd1) in _BASEPAIR_RULES:
        hbonds,angle_atoms=_BASEPAIR_RULES[(resd2,resd1)]
        residues=(_atom_coords(yyy),_atom_coords(xxx))
    else:
        return 0

    # Every atom used below must exist, including the ones used for the angle.
    required=[(0,donor) for donor,_ in hbonds]
    required+=[(1,acceptor) for _,acceptor in hbonds]
    required+=list(angle_atoms)
    for side,atom in required:
        if atom not in residues[side]:
            return -1

    lengths=[distant(*(residues[0][donor]+residues[1][acceptor])) for donor,acceptor in hbonds]
    for length in lengths:
        if not HBmin<length<HBmax:
            return 0

    p1,p2,p3=[residues[side][atom] for side,atom in angle_atoms]
    if angle(*(p1+p2+p3))<math.pi/3+0.05:
        return 1
    return 0


def chain_into_resds(chaintxt):
    """Split the ATOM records of one chain into per-residue blocks.

    chaintxt -- PDB text of (at most) one chain

    Only ATOM lines are read.  Parsing stops, after printing a warning, at
    the first ATOM line whose residue name is not in code_rna.  Lines whose
    chain ID differs from the first one seen are ignored, as are lines with
    alternate-location flag 'B'.  Residues without a C3' atom are dropped.

    Returns (chain_name, index, resd_name, resd_allatom,
             resd_c3x, resd_c3y, resd_c3z); the six lists are parallel:
      index        residue identifiers as written in the file (line[22:27])
      resd_name    residue names
      resd_allatom the residue's ATOM lines, each cut to 55 characters
      resd_c3x/y/z float coordinates of the residue's C3' atom
    chain_name stays '' when no ATOM line was accepted.
    """
    chain_name=''
    seen=set()   # residue identifiers already opened, for O(1) lookup
    found_id=[]
    found_name=[]
    found_atoms=[]
    for line in chaintxt.splitlines():
        if not line.startswith("ATOM"): continue
        if line[17:20].strip() not in code_rna:
            print('Warning: Chain '+line[21:22].strip()+' is not standard RNA, Please input standard RNA')
            break
        chain_id=line[21:22].strip()
        if chain_name!='' and chain_id!=chain_name:
            continue
        chain_name=chain_id
        if line[16:17]=='B': continue
        resid=line[22:27].strip()
        if resid not in seen:
            seen.add(resid)
            found_id.append(resid)
            found_name.append(line[17:20].strip())
            found_atoms.append(line[0:55]+'\n')
        else:
            # Known identifier: the line is added to the residue currently
            # being collected.
            found_atoms[-1]+=line[0:55]+'\n'

    index=[]
    resd_name=[]
    resd_allatom=[]
    resd_c3x=[]
    resd_c3y=[]
    resd_c3z=[]
    for resid,name,atoms in zip(found_id,found_name,found_atoms):
        c3=None
        for line in atoms.splitlines():
            if line[12:16].strip()=='C3\'':
                # Keep exactly one C3' per residue (the last one) so the
                # coordinate lists can never drift out of step with index.
                c3=(float(line[30:38].strip()),float(line[38:46].strip()),float(line[46:54].strip()))
        if c3 is None:
            continue   # residue without C3' is discarded
        index.append(resid)
        resd_name.append(name)
        resd_allatom.append(atoms)
        resd_c3x.append(c3[0])
        resd_c3y.append(c3[1])
        resd_c3z.append(c3[2])
    return chain_name,index,resd_name,resd_allatom,resd_c3x,resd_c3y,resd_c3z

def calbp(bp,ii,jj,ll):
    """Count the candidate pairs stacked on pair (ii, jj), the pair itself included.

    bp -- square 0/1 matrix of candidate pairs
    ll -- number of residues (size of bp)

    Up to three consecutive pairs are counted outwards (ii-k, jj+k) and up
    to three inwards (ii+k, jj-k).  If neither direction contributed, one
    neighbour shifted by a single residue (a one-nucleotide bulge) may raise
    the count to 2.
    """
    stacked=1

    # Outward run: stop at the first missing pair.
    for k in (1,2,3):
        if not (ii-k>-1 and jj+k<ll and bp[ii-k][jj+k]==1):
            break
        stacked+=1

    # Inward run: stop at the first missing pair.
    for k in (1,2,3):
        if not (jj-k>ii+k and bp[ii+k][jj-k]==1):
            break
        stacked+=1

    # Bulged neighbours only count for an otherwise isolated pair.
    for di,dj in ((2,1),(1,2)):
        if ii-di>-1 and jj+dj<ll and bp[ii-di][jj+dj]==1 and stacked==1:
            stacked+=1
    for di,dj in ((2,1),(1,2)):
        if ii+di<jj-dj and bp[ii+di][jj-dj]==1 and stacked==1:
            stacked+=1

    return stacked
  

def contactBP(pdbid,pdbtxt):
    """Detect the base pairs of one chain.

    pdbid  -- not used by this function
    pdbtxt -- PDB text of one chain (passed to chain_into_resds)

    Returns a 7-tuple
      (0, chain ID, sequence, pairing string, discontinuity string,
       CT text, ATOM lines of the kept residues + 'TER' line)
    The pairing string holds one digit per residue: 0 unpaired, 1 paired with
    a downstream residue, 2 paired with an upstream residue.  When no chain
    could be read, (-1,' ',' ',' ',' ',' ',' ') is returned instead.
    """
    C_name,index,resd_name,resd_allatom,resd_c3x,resd_c3y,resd_c3z=chain_into_resds(pdbtxt)
    if C_name=='':
        return -1,' ',' ',' ',' ',' ',' '

    nres=len(index)
    bpconfir=[0]*nres
    contactbp=[[0]*nres for _ in range(nres)]
    bpcheck=[[0]*nres for _ in range(nres)]

    # Step 1: candidate pairs from hydrogen-bond geometry.  Partners must be
    # at least four residues apart with C3' atoms at least 10 apart.
    for iii in range(nres):
        for jjj in range(iii+4,nres):
            if distant(resd_c3x[iii],resd_c3y[iii],resd_c3z[iii],resd_c3x[jjj],resd_c3y[jjj],resd_c3z[jjj]) < 10: continue
            if basepair_type(resd_name[iii],resd_name[jjj])==0: continue
            if basepair(resd_allatom[iii],resd_allatom[jjj],resd_name[iii],resd_name[jjj],4)==1:
                bpcheck[iii][jjj]=1

    # Step 2: give every residue at most one partner.  A competing candidate
    # that shares a residue with (iii, jjj) replaces it when it sits in a
    # longer stack according to calbp.
    for iii in range(nres):
        for jjj in range(nres):
            if bpcheck[iii][jjj]==0: continue
            if bpconfir[iii]>0 or bpconfir[jjj]>0: continue
            nbp=calbp(bpcheck,iii,jjj,nres)
            up,down=iii,jjj
            for i in range(nres):
                if bpcheck[i][jjj]==1 and calbp(bpcheck,i,jjj,nres)>nbp:
                    up=i
                    break
                if bpcheck[iii][i]==1 and calbp(bpcheck,iii,i,nres)>nbp:
                    # The pair is (iii, i); it used to be recorded on the
                    # diagonal as contactbp[i][i].
                    down=i
                    break
            bpconfir[up]=1
            bpconfir[down]=2
            contactbp[up][down]=1

    # Step 3: extend each helix by one pair on either side when the
    # neighbouring residues are both unpaired and pass the relaxed
    # C3'-distance test.
    for iii in range(0,nres-2):
        for jjj in range(iii+1,nres):
            if contactbp[iii][jjj]>0 and bpconfir[iii+1]==0 and bpconfir[jjj-1]==0:
                if R_basepair(resd_allatom[iii+1],resd_allatom[jjj-1],resd_name[iii+1],resd_name[jjj-1])==1:
                    contactbp[iii+1][jjj-1]=1
                    bpconfir[iii+1]=1
                    bpconfir[jjj-1]=2

            if iii>0 and jjj+1<nres:
                if contactbp[iii][jjj]>0 and bpconfir[iii-1]==0 and bpconfir[jjj+1]==0:
                    if R_basepair(resd_allatom[iii-1],resd_allatom[jjj+1],resd_name[iii-1],resd_name[jjj+1])==1:
                        contactbp[iii-1][jjj+1]=1
                        bpconfir[iii-1]=1
                        bpconfir[jjj+1]=2

    # Chain discontinuity: '1' where the next residue's C3' is more than 10
    # away; the last residue always gets '0'.
    disc=[]
    for iii in range(0,nres-1):
        if distant(resd_c3x[iii],resd_c3y[iii],resd_c3z[iii],resd_c3x[iii+1],resd_c3y[iii+1],resd_c3z[iii+1]) > 10:
            disc.append('1')
        else:
            disc.append('0')
    disc.append('0')

    bpl=0
    ct_lines=[]
    for iii in range(nres):
        for jjj in range(nres):
            if contactbp[iii][jjj] in (1,2):
                ct_lines.append('BP:'+str(iii).rjust(8,' ')+str(jjj).rjust(8,' ')+'\n')
    ct='\n'+''.join(ct_lines)
    rnaseq=''.join(resd_name)
    SS=''.join(str(state) for state in bpconfir)
    return bpl,C_name,rnaseq,SS,''.join(disc),ct,''.join(resd_allatom)+'TER    '+'\n'



       
def selectrnachain(pdbid,pdb_cont):
    """Print the base-pair report of every chain found in pdb_cont.

    pdbid    -- label printed after 'Title:'
    pdb_cont -- complete text of the PDB file

    The text is cut into segments at every line starting with 'TER'; each
    segment is handed to contactBP and segments without a usable chain are
    skipped.  The segment after the last TER is examined as well, so a file
    that lacks a TER record is still reported.  Returns None.
    """
    segments=[[]]
    for line in pdb_cont.splitlines():
        if line.startswith('TER'):
            segments.append([])
        else:
            segments[-1].append(line)

    for segment in segments:
        status,Ch_name,seq,pairing,breaks,ct,coords=contactBP(pdbid,'\n'.join(segment))
        if status==-1: continue
        # Pairing digits -> symbols: 0 unpaired, 1 pairs downstream, 2 pairs upstream.
        structure=pairing.replace('0','.').replace('1','<').replace('2','>')
        seqall='\nTitle: '+pdbid+' ChainID: '+Ch_name+'\nSEQ: '+seq+'\nSS:  '+structure+'\n'+'DIS: '+breaks+'\nCT:  '+ct
        print(seqall)

def readpdb(pdbf):
    """Return the whole content of the text file pdbf as one string."""
    # Mode 'rU' is rejected by Python 3.11+; plain text mode already applies
    # universal-newline translation on Python 3.  The with-block closes the
    # file even when read() fails.
    with open(pdbf,'r') as fp:
        return fp.read()



if __name__=='__main__':
    # Command line: one argument, the path of an existing file ending in
    # '.pdb'.  Anything else prints the usage text to stderr.  sys.exit is
    # used because the bare exit() helper is only provided by the site
    # module for interactive sessions.
    if len(sys.argv)<2:
        sys.stderr.write(docstring)
        sys.exit()
    pdbfile=sys.argv[1]
    if not os.path.exists(pdbfile) or not pdbfile.endswith('.pdb'):
        sys.stderr.write(docstring)
        sys.exit()

    pdbcont=readpdb(pdbfile)
    infor=selectrnachain(str(pdbfile),pdbcont)