-- | ViennaRNA folding based on an algebraic ring structure. This should
-- combine the goals of few lines of code, multiple different folding
-- functions, and extensibility.
--
-- NOTE Assume that you want '-d 3' for folding with dangles. Then you can just
-- instantiate the folding functions, replacing only those functions where the
-- folding changes based on the new dangle options.
--
-- NOTE compile with: -fno-method-sharing

module BioInf.RNAFold where

import Control.Monad
import Control.Monad.ST

import Biobase.RNA
import Biobase.Types.Ring
import Data.PrimitiveArray
import Biobase.Structure

import BioInf.RNAFold.Functions

import Debug.Trace.Tools
import Debug.Trace


-- | Folding works on unboxed values of a Ring type for which a FoldFunctions
-- instance exists. By default, we have this for Energy values. Again, we
-- use a class because we could be interested in probabilistic backtracking or
-- something similar.

-- The five tables produced by one folding run, all holding values of the
-- same type 'a'. 'foldST' returns them in exactly this order:
-- (weak, strong, mbr1, mbr, extern).

type ResultTables a =
  ( Table a -- weak structures
  , Table a -- strong structures
  , Table a -- exactly one component
  , Table a -- one or more components
  , Table a -- complete external structures
  )

-- A list of integer index pairs. Not referenced in the visible part of this
-- module; presumably each entry is an (i,j) base pair -- TODO confirm.
type Pairlist = [(Int,Int)]

-- Class of value types that can be folded. A 'FoldFunctions' instance is
-- required as superclass. The individual loop functions called in 'foldST'
-- (hairpinOpt, stackOpt, ...) are not defined in this module; presumably
-- they are provided through 'FoldFunctions' -- TODO confirm.
class (FoldFunctions a) => Fold a where

  -- Pure entry point: fills and returns all result tables for the input.
  fold      :: TurnerTables a -> Primary -> (ResultTables a)
  -- The table-filling computation itself, running in 'ST'.
  foldST    :: TurnerTables a -> Primary -> ST s (ResultTables a)
  -- Backtracking over filled tables. No default is given in this class
  -- body, so every instance has to supply its own implementation.
  backtrack :: TurnerTables a -> Primary -> (ResultTables a) -> a -> [(Secondary,a)]

  -- | We have a default instance for folding based on Rings

  -- 'fold' simply runs 'foldST' to completion with 'runST'.
  fold trnr inp = runST $ foldST trnr inp
  {-# INLINE fold #-}

  foldST trnr inp = do
    -- n is the upper bound of the input sequence; all tables are indexed
    -- from (0,0) to (n,n) (see 'mkTableWith').
    let n = snd $ bounds inp
    -- Every table comes as a pair: a mutable handle (suffix M) that is
    -- written to, and a frozen view that the recursion reads from.
    -- NOTE(review): the frozen view is obtained with unsafeFreezeM, so it
    -- presumably shares storage with the mutable handle and later reads see
    -- earlier writes -- this is why the fill order below matters. Confirm.
    (weakM,weak)     <- mkTable n
    (strongM,strong) <- mkTable n
    -- The external table starts out filled with 'one', the others with 'zero'.
    (externM,extern) <- mkTableWith one n
    (mbr1M,mbr1)     <- mkTable n
    (mbrM,mbr)       <- mkTable n
    -- Main fill: i runs downwards from n to 0, j upwards from i to n.
    forM_ [n,n-1..0] $ \i -> forM_ [i,i+1..n] $ \j -> do
      let pIJ = pair inp i j
      -- Weak/strong cells are only written if (i,j) is not the no-pair value
      -- 'vpNP' and j is more than three positions after i.
      when (pIJ/=vpNP&&i+3<j) $ do
        -- weak table
        let hpVal = {-# SCC "hpVal" #-} hairpinOpt trnr inp i j
        -- All interior-loop style contributions read the strong table.
        let ilVal = {-# SCC "ilVal" #-} ringSumL
              [ largeInteriorLoopOpt trnr inp strong i j
              , tabbedInteriorLoopOpt trnr inp strong i j
              , bulgeLOpt trnr inp strong i j
              , bulgeROpt trnr inp strong i j
              , interior1xnLOpt trnr inp strong i j
              , interior1xnROpt trnr inp strong i j
              ]
        let mbVal = {-# SCC "mbVal" #-} multibranchCloseOpt trnr inp mbr mbr1 i j
        writeM weakM (i,j) $ ringSumL [hpVal,ilVal,mbVal]
        -- strong table
        -- Written only under the stricter distance condition i+5<j; it
        -- combines a stack on a weak cell with a stack on a strong cell.
        when (i+5<j) $ do
          let stValW = {-# SCC "stValW" #-} stackOpt trnr inp weak i j
          let stValS = {-# SCC "stValS" #-} stackOpt trnr inp strong i j
          writeM strongM (i,j) $ ringSumL [stValW,stValS]
      -- multibranch loops
      -- Only filled for cells with i>0 and j<n; independent of whether
      -- (i,j) can pair.
      when (i>0&&j<n) $ do
        -- M1
        let mbr1ValS = {-# SCC "mbr1ValS" #-} multibranchIJLoopOpt trnr inp strong i j
        let mbr1ValU = {-# SCC "mbr1ValU" #-} multibranchUnpairedJOpt trnr inp mbr1 i j
        writeM mbr1M (i,j) $ ringSumL [mbr1ValS,mbr1ValU]
        -- M
        let mbrValU  = {-# SCC "mbrValU" #-} multibranchUnpairedJOpt trnr inp mbr i j
        let mbrValS  = {-# SCC "mbrValS" #-} multibranchKJHelixOpt trnr inp strong i j
        let mbrValMS = {-# SCC "mbrValMS" #-} multibranchAddKJHelixOpt trnr inp mbr strong i j
        writeM mbrM (i,j) $ ringSumL [mbrValU,mbrValS,mbrValMS]
    -- fill only part of the F array
    -- Only the column j = n is written, for i from n-6 down to 0; all other
    -- cells of the external table keep their initial value 'one'.
    let j=n
    forM_ [n-6,n-7..0] $ \i -> do
      -- Within this loop i <= n-6, so the test i<n always holds and the
      -- 'else zero' branch is never taken.
      let extUP   = {-# SCC "extUP"   #-} if i<n then extern ! (i+1,j) else zero
      let extStr  = {-# SCC "extStr"  #-} externalLoopOpt trnr inp strong i j
      let extAddL = {-# SCC "extAddL" #-} externalAddLoopOpt trnr inp strong extern i j
      writeM externM (i,j) $ ringSumL [extUP,extStr,extAddL,one] -- always add 'one' as the open chain should always be contained
    -- Hand back the frozen views, in the order fixed by 'ResultTables'.
    return (weak,strong,mbr1,mbr,extern)
  {-# INLINE foldST #-}


-- * Helper functions

-- Build a table pair whose cells all start at the ring's 'zero'. This is
-- 'mkTableWith' with the initial value fixed; 'size' is the upper index
-- bound in both dimensions.

mkTable size = mkTableWith zero size

-- Build a table pair with a caller-chosen initial value. The table is
-- created via 'fromAssocsM' with bounds (0,0)..(size,size), the given value
-- and an empty association list, then frozen via 'unsafeFreezeM'. The result
-- is (mutable handle, frozen view).
-- NOTE(review): both helpers are imported; presumably the frozen view is not
-- a copy of the mutable table -- confirm against Data.PrimitiveArray.

mkTableWith initial size =
  fromAssocsM (0,0) (size,size) initial [] >>= \mutable ->
  unsafeFreezeM mutable >>= \frozen ->
  return (mutable,frozen)