module BioInf.RNAFold where
import Control.Monad
import Control.Monad.ST
import Biobase.RNA
import Biobase.Types.Ring
import Data.PrimitiveArray
import Biobase.Structure
import BioInf.RNAFold.Functions
import Debug.Trace.Tools
import Debug.Trace
-- | The five dynamic-programming tables produced by a fold.
--
-- The default 'foldST' of 'Fold' returns them in the order
-- (weak, strong, mbr1, mbr, extern).
type ResultTables a = (Table a, Table a, Table a, Table a, Table a)
-- | A list of integer index pairs.
-- NOTE(review): presumably base-pair positions (i,j) of a structure — confirm.
type Pairlist = [ (Int, Int) ]
-- | Folding interface built on top of 'FoldFunctions'.
--
-- 'fold' runs 'foldST' purely via 'runST'. The default 'foldST' fills five
-- dynamic-programming tables over index pairs (i,j) with 0 <= i <= j <= n
-- and returns them as (weak, strong, mbr1, mbr, extern).
-- 'backtrack' has no default and must be supplied by every instance.
class (FoldFunctions a) => Fold a where
  -- | Pure wrapper around 'foldST'.
  fold      :: TurnerTables a -> Primary -> (ResultTables a)
  -- | Fill all result tables for the given input inside 'ST'.
  foldST    :: TurnerTables a -> Primary -> ST s (ResultTables a)
  -- | Recover structures from filled tables.
  -- NOTE(review): the meaning of the final @a@ argument is defined by the
  -- instances, not here — confirm against them.
  backtrack :: TurnerTables a -> Primary -> (ResultTables a) -> a -> [(Secondary,a)]

  -- Run the stateful fill in a fresh state thread.
  fold trnr inp = runST $ foldST trnr inp

  foldST trnr inp = do
    -- assumes 'bounds inp' is (0, n) so it agrees with the (0,0)..(n,n) tables
    let n = snd $ bounds inp
    (weakM,weak)     <- mkTable n
    (strongM,strong) <- mkTable n
    -- extern starts at 'one'; cells the extern loop below skips keep that value
    (externM,extern) <- mkTableWith one n
    (mbr1M,mbr1)     <- mkTable n
    (mbrM,mbr)       <- mkTable n
    -- The frozen tables are taken before any write, and the recurrences read
    -- sub-results through them while this loop writes via the mutable handles.
    -- NOTE(review): this presumably relies on unsafeFreezeM not copying.
    -- i descends from n and j ascends from i, so every (k,l) /= (i,j) with
    -- i <= k <= l <= j has already been written when (i,j) is computed.
    forM_ [n, n - 1 .. 0] $ \i -> forM_ [i .. n] $ \j -> do
      let pIJ = pair inp i j
      -- weak: the pair type is not vpNP and the span encloses >= 3 positions
      when (pIJ /= vpNP && i + 3 < j) $ do
        let hpVal = hairpinOpt trnr inp i j
        let ilVal = ringSumL
              [ largeInteriorLoopOpt trnr inp strong i j
              , tabbedInteriorLoopOpt trnr inp strong i j
              , bulgeLOpt trnr inp strong i j
              , bulgeROpt trnr inp strong i j
              , interior1xnLOpt trnr inp strong i j
              , interior1xnROpt trnr inp strong i j
              ]
        let mbVal = multibranchCloseOpt trnr inp mbr mbr1 i j
        writeM weakM (i,j) $ ringSumL [hpVal, ilVal, mbVal]
        -- strong is only written for longer spans (i+5 < j)
        when (i + 5 < j) $ do
          let stValW = stackOpt trnr inp weak i j
          let stValS = stackOpt trnr inp strong i j
          writeM strongM (i,j) $ ringSumL [stValW, stValS]
      -- multibranch helper tables are only filled strictly inside (0, n)
      when (i > 0 && j < n) $ do
        let mbr1ValS = multibranchIJLoopOpt trnr inp strong i j
        let mbr1ValU = multibranchUnpairedJOpt trnr inp mbr1 i j
        writeM mbr1M (i,j) $ ringSumL [mbr1ValS, mbr1ValU]
        let mbrValU  = multibranchUnpairedJOpt trnr inp mbr i j
        let mbrValS  = multibranchKJHelixOpt trnr inp strong i j
        let mbrValMS = multibranchAddKJHelixOpt trnr inp mbr strong i j
        writeM mbrM (i,j) $ ringSumL [mbrValU, mbrValS, mbrValMS]
    -- extern is only computed for the last column (j = n). Since strong is
    -- never written for spans shorter than 6, the loop starts at i = n-6 and
    -- rows i > n-6 keep their initial 'one'.
    let j = n
    forM_ [n - 6, n - 7 .. 0] $ \i -> do
      let extUP   = if i < n then extern ! (i+1,j) else zero
      let extStr  = externalLoopOpt trnr inp strong i j
      let extAddL = externalAddLoopOpt trnr inp strong extern i j
      writeM externM (i,j) $ ringSumL [extUP, extStr, extAddL, one]
    return (weak,strong,mbr1,mbr,extern)
  
-- | Allocate a table indexed (0,0)..(size,size) with every cell set to
-- 'zero', returning both the mutable handle and its frozen view.
-- (Kept with an explicit argument: a point-free binding would be subject
-- to the monomorphism restriction.)
mkTable size = mkTableWith zero size
-- | Allocate a table indexed (0,0)..(n,n) with every cell set to @v@.
--
-- Returns (mutable handle, frozen view). The view is obtained with
-- 'unsafeFreezeM' immediately after allocation.
-- NOTE(review): callers appear to rely on later writes through the mutable
-- handle being visible via the view — confirm unsafeFreezeM does not copy.
mkTableWith v n =
  fromAssocsM (0,0) (n,n) v [] >>= \mutable ->
  unsafeFreezeM mutable >>= \frozen ->
  return (mutable, frozen)