module BioInf.RNAFold.Functions
  ( FoldFunctions (..)
  , Table
  , TurnerTables
  , ringSumL
  , pair
  , riap
  , ringProductL
  , tabbedInteriorLoopDistances
  ) where
import qualified Data.Vector.Unboxed as VU
import Control.Exception (assert)
import Data.List (foldl')
import qualified Data.Map as M
import Biobase.RNA
import Biobase.Turner.Tables
import Biobase.Types.Ring
import Biobase.Vienna
import Data.PrimitiveArray
import Data.Primitive.Types
import Debug.Trace.Tools
type Cell = (Int,Int)
type Table a = PrimArray Cell a
type TurnerTables a = Turner2004 ViennaPair Nucleotide a
-- | Loop-energy functions used by the folding recursions, generic in the
-- ring @a@ in which partial scores are combined.
--
-- Every loop type comes as a pair of methods:
--
-- * @…Opt@ combines all candidates for the cell @(i,j)@ with '.+.',
--   starting from 'zero'.
-- * @…Idx@ lists the candidates, each paired with the table cell (or, for
--   the multibranch / external splits, the split index @k@) it came from.
--
-- All of these have defaults built on the module-local @…Base@ functions;
-- only 'calcNinio', 'calcTermAU' and 'calcLargeLoop' have no default.
class (Show a, Ring a, VU.Unbox a, Prim a) => FoldFunctions a where
  stackOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  stackIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  hairpinOpt  :: TurnerTables a -> Primary -> Int -> Int -> a
  hairpinIdx  :: TurnerTables a -> Primary -> Int -> Int -> [a]
  largeInteriorLoopOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  largeInteriorLoopIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  tabbedInteriorLoopOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  tabbedInteriorLoopIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  bulgeLOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  bulgeLIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  bulgeROpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  bulgeRIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  interior1xnLOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  interior1xnLIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  interior1xnROpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  interior1xnRIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  multibranchIJLoopOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  multibranchIJLoopIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  multibranchUnpairedJOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  multibranchUnpairedJIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  multibranchKJHelixOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  multibranchKJHelixIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Int,a)]
  multibranchAddKJHelixOpt  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> a
  multibranchAddKJHelixIdx  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> [(Int,a)]
  multibranchCloseOpt  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> a
  multibranchCloseIdx  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> [(Int,a)]
  externalLoopOpt  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> a
  externalLoopIdx  :: TurnerTables a -> Primary -> Table a -> Int -> Int -> [(Cell,a)]
  externalAddLoopOpt  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> a
  externalAddLoopIdx  :: TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> [(Int,a)]

  -- | Ninio asymmetry term. Every call in this module has the shape
  -- @calcNinio maxNinio ninio asymmetry@.
  calcNinio :: a -> a -> Int -> a

  -- | Terminal penalty for a pair type; called as @calcTermAU termAU pairType@.
  -- Presumably yields the penalty for AU/GU pairs and the neutral element
  -- otherwise — confirm against the instances.
  calcTermAU :: a -> ViennaPair -> a

  -- | Length term for hairpins with more than 30 unpaired nucleotides; it is
  -- combined with the tabulated length-30 value by 'hairpinBase'.
  calcLargeLoop :: Int -> a

  -- Opt defaults: ring-sum all candidates produced by the matching …Base.
  stackOpt trnr inp tbl i j =
    VU.foldl' (.+.) zero $ stackBase trnr inp tbl i j
  hairpinOpt trnr inp i j =
    VU.foldl' (.+.) zero $ hairpinBase trnr inp i j
  largeInteriorLoopOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ largeInteriorLoopBase trnr inp strong i j
  tabbedInteriorLoopOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ tabbedInteriorLoopBase trnr inp strong i j
  bulgeLOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ bulgeLBase trnr inp strong i j
  bulgeROpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ bulgeRBase trnr inp strong i j
  interior1xnLOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ interior1xnLBase trnr inp strong i j
  interior1xnROpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ interior1xnRBase trnr inp strong i j
  multibranchIJLoopOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ multibranchIJLoopBase trnr inp strong i j
  multibranchUnpairedJOpt trnr inp mtable i j =
    VU.foldl' (.+.) zero $ multibranchUnpairedJBase trnr inp mtable i j
  multibranchKJHelixOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ multibranchKJHelixBase trnr inp strong i j
  multibranchAddKJHelixOpt trnr inp table strong i j =
    VU.foldl' (.+.) zero $ multibranchAddKJHelixBase trnr inp table strong i j
  multibranchCloseOpt trnr inp m m1 i j =
    VU.foldl' (.+.) zero $ multibranchCloseBase trnr inp m m1 i j
  externalLoopOpt trnr inp strong i j =
    VU.foldl' (.+.) zero $ externalLoopBase trnr inp strong i j
  externalAddLoopOpt trnr inp strong extern i j =
    VU.foldl' (.+.) zero $ externalAddLoopBase trnr inp strong extern i j

  -- Idx defaults: pair every candidate with the cell / split index it uses.
  -- The index generators mirror the ones used inside the …Base functions so
  -- that the zips line up element by element.
  stackIdx trnr inp tbl i j =
    [((i+1, j-1), stackOpt trnr inp tbl i j)]
  hairpinIdx trnr inp i j =
    VU.toList $ hairpinBase trnr inp i j
  largeInteriorLoopIdx trnr inp strong i j =
    VU.toList $ VU.zip (interiorLoopIndices i j) (largeInteriorLoopBase trnr inp strong i j)
  tabbedInteriorLoopIdx trnr inp strong i j =
    VU.toList $ VU.zip (tabbedInteriorLoopIndices i j) $ tabbedInteriorLoopBase trnr inp strong i j
  bulgeLIdx trnr inp strong i j =
    VU.toList $ VU.zip (VU.map (\k -> (i+1+k, j-1)) . uncurry VU.enumFromN $ bulgeLimit i j) $ bulgeLBase trnr inp strong i j
  bulgeRIdx trnr inp strong i j =
    VU.toList $ VU.zip (VU.map (\k -> (i+1, j-1-k)) . uncurry VU.enumFromN $ bulgeLimit i j) $ bulgeRBase trnr inp strong i j
  interior1xnLIdx trnr inp strong i j =
    VU.toList $ VU.zip (VU.map (\k -> (i+1+k, j-2)) $ uncurry VU.enumFromN $ iloop1xnLimit i j) $ interior1xnLBase trnr inp strong i j
  interior1xnRIdx trnr inp strong i j =
    VU.toList $ VU.zip (VU.map (\k -> (i+2, j-1-k)) $ uncurry VU.enumFromN $ iloop1xnLimit i j) $ interior1xnRBase trnr inp strong i j
  multibranchIJLoopIdx trnr inp strong i j =
    VU.toList $ VU.zip (VU.singleton (i,j)) $ multibranchIJLoopBase trnr inp strong i j
  multibranchUnpairedJIdx trnr inp mtable i j =
    VU.toList $ VU.zip (VU.singleton (i, j-1)) $ multibranchUnpairedJBase trnr inp mtable i j
  multibranchKJHelixIdx trnr inp strong i j =
    VU.toList $ VU.zip (uncurry VU.enumFromN $ multibranchKJHelixLimit i j) $ multibranchKJHelixBase trnr inp strong i j
  multibranchCloseIdx trnr inp m m1 i j =
    VU.toList $ VU.zip (uncurry VU.enumFromN $ multibranchCloseLimit i j) $ multibranchCloseBase trnr inp m m1 i j
  multibranchAddKJHelixIdx trnr inp table strong i j =
    VU.toList $ VU.zip (uncurry VU.enumFromN $ multibranchAddKJHelixLimit i j) $ multibranchAddKJHelixBase trnr inp table strong i j
  externalLoopIdx trnr inp strong i j =
    [((i,j), externalLoopOpt trnr inp strong i j)]
  externalAddLoopIdx trnr inp strong extern i j =
    VU.toList $ VU.zip (uncurry VU.enumFromN $ externalAddLoopLimit i j) $ externalAddLoopBase trnr inp strong extern i j
-- | Hairpin loop closed by the pair @(i,j)@; always a single candidate.
--
-- * not a valid pair, or fewer than 3 unpaired nucleotides: 'zero'
-- * up to 6 unpaired nucleotides and the loop sequence (both closing
--   nucleotides included) is listed in @hairpinLookup@: the tabulated value
-- * other triloops: length term plus terminal AU/GU penalty, no mismatch
-- * longer loops: length term times the terminal mismatch; above 30 the
--   length term is the length-30 value combined with 'calcLargeLoop'
hairpinBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Int -> Int -> VU.Vector a
hairpinBase Turner2004{..} inp i j = VU.singleton go where
  go
    | pIJ==vpNP || l<3 = zero
    | l<=6, Just v <- s `M.lookup` hairpinLookup = v
    -- triloops carry no mismatch term but do pay the terminal AU/GU penalty
    | l==3      = (hairpinL ! l) .*. tAU
    | l>=31     = (hairpinL ! 30) .*. (hairpinMM ! (pIJ,bI,bJ)) .*. (calcLargeLoop l)
    | otherwise = (hairpinL ! l) .*. (hairpinMM ! (pIJ,bI,bJ))
  -- number of unpaired nucleotides in the loop
  l = j-i-1
  -- special-hairpin lookup key: the loop including both closing nucleotides
  s = assert (i>=0 && checkBounds inp j) $ [inp ! k | k <- [i..j]]
  bI = inp ! (i+1)
  bJ = inp ! (j-1)
  pIJ = pair inp i j
  tAU = calcTermAU termAU pIJ
-- | Stacked pair: @(i,j)@ closing directly on @(i+1,j-1)@; a single candidate.
-- Yields 'zero' if either pair is not canonical or the inner cell of @tbl@
-- is 'zero'; otherwise the inner value times the stacking term.
stackBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
stackBase Turner2004{..} inp tbl i j = VU.singleton $ go (i+1, j-1) where
  pIJ = pair inp i j
  go (k,l)
    | pIJ==vpNP || pKL==vpNP || isZero tE = zero
    | otherwise                          = tE .*. (stack ! (pIJ,pKL))
    where
      -- the inner pair is indexed in reverse orientation, (l,k)
      pKL = riap inp k l
      tE  = tbl ! (k,l)
-- | Small interior loops and size-one bulges closed by @(i,j)@ whose energies
-- come from dedicated tables. One candidate per inner pair listed by
-- 'tabbedInteriorLoopIndices': the inner-pair value times the loop term.
tabbedInteriorLoopBase :: (Show a, FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
tabbedInteriorLoopBase Turner2004{..} inp strong i j = res where
  res = VU.map (\(k,l) -> (strong ! (k,l)) .*. (ftabbed (k-i-1, j-l-1) (k,l))) ili
  ili = tabbedInteriorLoopIndices i j
  -- (di,dj): unpaired nucleotides 5' and 3' of the inner pair (k,l)
  ftabbed (di,dj) (k,l)
    -- bulge of size one: bulge term plus stacking of the two pairs
    | ds==0 && dl==1 = bulgeL!1 .*. stack!(pIJ,pKL)
    | ds==1 && dl==1 = iloop1x1 ! (pIJ,pKL,bI,bJ)
    | di==1 && dj==2 = iloop1x2 ! (pIJ,pKL,bI,bL,bJ)
    | di==2 && dj==1 = iloop1x2 ! (pIJ,pKL,bJ,bI,bK)
    | ds==2 && dl==2 = iloop2x2 ! (pIJ,pKL,bI,bK,bL,bJ)
    -- 2x3 loops: dedicated mismatch tables, generic length term, one Ninio unit
    | ds==2 && dl==3 = iloop2x3MM!(pIJ,bI,bJ) .*. iloop2x3MM ! (pKL,bL,bK) .*. iloopL ! 5 .*. ninio
    -- unreachable for the offsets produced by tabbedInteriorLoopDistances
    | otherwise      = error $ "tabbedInteriorLoopBase: no tabulated loop for " ++ show (di,dj)
    where
      pKL = riap inp k l
      bK  = inp ! (k-1)
      bL  = inp ! (l+1)
      ds  = min di dj
      dl  = max di dj
  pIJ = pair inp i j
  bI  = inp ! (i+1)
  bJ  = inp ! (j-1)
-- | Generic interior loops closed by @(i,j)@, one candidate per offset pair
-- from 'interiorLoopDistances'. The inner pair is @(k,l) = (i+di, j-dj)@.
-- Each candidate combines the outer mismatch, the inner-pair value, the
-- length term for @di+dj-2@ unpaired nucleotides, the inner mismatch and
-- the Ninio asymmetry term for @|di-dj|@.
largeInteriorLoopBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
largeInteriorLoopBase Turner2004{..} inp strong i j = VU.map candidate didjs where
  candidate (di,dj) =
    let k = i+di
        l = j-dj
    in  ijmm .*.
        (strong ! (k,l)) .*.
        (iloopL ! (di+dj-2)) .*.
        -- Inner pair seen from inside the loop as (l,k): same convention as
        -- the outer mismatch (i+1, j-1), i.e. 3' neighbour of l, then 5'
        -- neighbour of k. Matches the 2x3 and 1xnR mismatch lookups.
        (iloopMM ! (riap inp k l, inp ! (l+1), inp ! (k-1))) .*.
        calcNinio maxNinio ninio (abs $ di-dj)
  didjs = interiorLoopDistances i j
  ijmm  = iloopMM ! (pIJ,bI,bJ)
  pIJ   = pair inp i j
  bI    = inp ! (i+1)
  bJ    = inp ! (j-1)
-- | Bulges of size @k@ on the 5' side of the loop closed by @(i,j)@; the inner
-- pair is @(i+1+k, j-1)@ and @k@ ranges over 'bulgeLimit'. Each candidate is
-- the inner value times the bulge length term and the terminal AU/GU terms
-- of both pairs.
bulgeLBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
bulgeLBase Turner2004{..} inp strong i j = res where
  res = VU.map (\k ->
                  strong ! (i+1+k, j-1) .*.
                  bulgeL ! k .*.
                  calcTermAU termAU (riap inp (i+1+k) (j-1)) .*.
                  tAUij
                ) . uncurry VU.enumFromN $ bulgeLimit i j
  tAUij = calcTermAU termAU $ pair inp i j
-- | Bulges of size @k@ on the 3' side of the loop closed by @(i,j)@; the inner
-- pair is @(i+1, j-1-k)@ and @k@ ranges over 'bulgeLimit'. Mirror image of
-- 'bulgeLBase'.
bulgeRBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
bulgeRBase Turner2004{..} inp strong i j = res where
  res = VU.map (\k ->
                  strong ! (i+1, j-1-k) .*.
                  bulgeL ! k .*.
                  calcTermAU termAU (riap inp (i+1) (j-1-k)) .*.
                  tAUij
                ) . uncurry VU.enumFromN $ bulgeLimit i j
  tAUij  = calcTermAU termAU $ pair inp i j
-- | 1xn interior loops with @k@ unpaired nucleotides on the 5' side and one on
-- the 3' side of the loop closed by @(i,j)@; the inner pair is
-- @(i+1+k, j-2)@ and @k@ ranges over 'iloop1xnLimit'.
interior1xnLBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
interior1xnLBase Turner2004{..} inp strong i j = res where
  res = VU.map (\k ->
                  strong ! (i+1+k, j-2) .*.
                  iloopL ! (k+1) .*.
                  -- inner pair seen from inside as (j-2, i+1+k): 3' neighbour
                  -- j-1 first, then 5' neighbour i+k (same order as
                  -- interior1xnRBase and the 2x3 lookup)
                  iloop1xnMM ! (riap inp (i+1+k) (j-2), inp ! (j-1), inp ! (i+k)) .*.
                  calcNinio maxNinio ninio (k-1) .*.
                  pIJmm
                ) . uncurry VU.enumFromN $ iloop1xnLimit i j
  pIJmm = iloop1xnMM ! (pair inp i j, inp ! (i+1), inp ! (j-1))
-- | 1xn interior loops with one unpaired nucleotide on the 5' side and @k@ on
-- the 3' side of the loop closed by @(i,j)@; the inner pair is
-- @(i+2, j-1-k)@ and @k@ ranges over 'iloop1xnLimit'.
interior1xnRBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
interior1xnRBase Turner2004{..} inp strong i j = res where
  res = VU.map (\k ->
                  strong ! (i+2, j-1-k) .*.
                  iloopL ! (k+1) .*.
                  iloop1xnMM ! (riap inp (i+2) (j-1-k), inp ! (j-k), inp ! (i+1)) .*.
                  calcNinio maxNinio ninio (k-1) .*.
                  pIJmm
                ) . uncurry VU.enumFromN $ iloop1xnLimit i j
  pIJmm = iloop1xnMM ! (pair inp i j, inp ! (i+1), inp ! (j-1))
-- | Multibranch loop closed by @(i,j)@, split at @k@ into @m ! (i+1,k)@ and
-- @m1 ! (k+1,j-1)@ (@k@ from 'multibranchCloseLimit'). Each candidate also
-- includes the closing-pair mismatch, the multiloop offset and one helix
-- term for the closing pair.
multibranchCloseBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> VU.Vector a
multibranchCloseBase Turner2004{..} inp m m1 i j = res where
  res = VU.map (\k ->
                  m ! (i+1,k) .*.
                  m1 ! (k+1, j-1) .*.
                  ijmm .*.
                  mbcl
                ) . uncurry VU.enumFromN $ multibranchCloseLimit i j
  -- closing pair seen from inside the loop as (j,i)
  ijmm = multiMM ! (pIJ,bJ,bI)
  mbcl = multiOffset .*. multiHelix
  pIJ  = riap inp i j
  bI   = inp ! (i+1)
  bJ   = inp ! (j-1)
-- | A single helix @(i,j)@ inside a multibranch loop: the pair value times
-- the helix term and the mismatch with the outer neighbours @i-1@ / @j+1@.
multibranchIJLoopBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
multibranchIJLoopBase Turner2004{..} inp strong i j = res where
  res    = VU.singleton $ strong ! (i,j) .*. mbrhlx .*. mbrmm
  mbrhlx = multiHelix
  mbrmm  = multiMM ! (pIJ,bI,bJ)
  pIJ    = pair inp i j
  bI     = inp ! (i-1)
  bJ     = inp ! (j+1)
-- | Extends @mtable ! (i,j-1)@ by one unpaired nucleotide @j@ inside a
-- multibranch loop (one @multiNuc@ term). @inp@ is not used.
multibranchUnpairedJBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
multibranchUnpairedJBase Turner2004{..} inp mtable i j = res where
  res = VU.singleton $ mtable ! (i, j-1) .*. mbrup
  mbrup = multiNuc
-- | First branch of a multibranch region @[i..j]@: @k-i@ unpaired nucleotides
-- followed by the helix @(k,j)@ (@k@ from 'multibranchKJHelixLimit'). Each
-- candidate pays the unpaired terms, the helix term and the helix mismatch.
multibranchKJHelixBase  :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
multibranchKJHelixBase Turner2004{..} inp strong i j = VU.map branch ks where
  ks = uncurry VU.enumFromN $ multibranchKJHelixLimit i j
  branch k =
    multiNuc .^. (k-i) .*.
    strong ! (k,j) .*.
    multiHelix .*.
    multiMM ! (pair inp k j, inp ! (k-1), inp ! (j+1))
-- | Adds a further branch to a multibranch region: @table ! (i,k)@ followed
-- directly by the helix @(k+1,j)@ (@k@ from 'multibranchAddKJHelixLimit').
-- Like every other branch ('multibranchKJHelixBase',
-- 'multibranchIJLoopBase'), the added helix pays the helix term and its
-- mismatch with the neighbours @k@ / @j+1@.
multibranchAddKJHelixBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> VU.Vector a
multibranchAddKJHelixBase Turner2004{..} inp table strong i j = res where
  res = VU.map (\k ->
                  table ! (i,k) .*.
                  strong ! (k+1,j) .*.
                  multiHelix .*.
                  multiMM ! (pair inp (k+1) j, inp ! k, inp ! (j+1))
                ) . uncurry VU.enumFromN $ multibranchAddKJHelixLimit i j
-- | A helix @(i,j)@ in the external loop: pair value times the terminal
-- AU/GU term and a dangle term. The dangle term depends on which neighbours
-- exist: both (@extMM@), only the 5' one (@dangle5@), only the 3' one
-- (@dangle3@), or neither ('one').
externalLoopBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Int -> Int -> VU.Vector a
externalLoopBase Turner2004{..} inp strong i j = res where
  res = VU.singleton $ strong ! (i,j) .*. mm .*. tAU
  n = snd $ bounds inp
  pIJ = pair inp i j
  -- only forced when the corresponding neighbour exists (see mm)
  bI = inp ! (i-1)
  bJ = inp ! (j+1)
  tAU = calcTermAU termAU pIJ
  mm
    | i>0&&j<n  = extMM ! (pIJ,bI,bJ)
    | i>0       = dangle5 ! (pIJ,bI)
    | j<n       = dangle3 ! (pIJ,bJ)
    | otherwise = one
-- | Splits the external region @[i..j]@ at @k@ (from 'externalAddLoopLimit'):
-- an external-loop helix closed by @(i,k)@ followed by @extern ! (k+1,j)@.
externalAddLoopBase :: (FoldFunctions a) => TurnerTables a -> Primary -> Table a -> Table a -> Int -> Int -> VU.Vector a
externalAddLoopBase trnr@Turner2004{} inp strong extern i j = VU.map combine splits
  where
    splits    = uncurry VU.enumFromN (externalAddLoopLimit i j)
    combine k = externalLoopOpt trnr inp strong i k .*. extern ! (k+1,j)
-- | Pair type formed by nucleotides @i@ (first) and @j@ (second) of @inp@.
-- Bounds are only checked through 'assert'; the reads themselves use
-- 'unsafeIndex'.
pair :: Primary -> Int -> Int -> ViennaPair
pair inp i j = assert inBounds (toPair bi bj)
  where
    inBounds = checkBounds inp i && checkBounds inp j
    bi       = inp `unsafeIndex` i
    bj       = inp `unsafeIndex` j
-- | Like 'pair' with the nucleotides swapped: the pair type of @(j,i)@, i.e.
-- a pair read in reverse orientation. Used in this module wherever a pair
-- is looked up as seen from inside an enclosing loop (inner pairs of stacks
-- and interior loops, the closing pair of a multibranch loop).
riap :: Primary -> Int -> Int -> ViennaPair
riap inp i j
  = assert (i>=0 && j>=0 && checkBounds inp i && checkBounds inp j)
  $ toPair (inp ! j) (inp ! i)
-- | Ring sum ('.+.') of a vector; 'zero' for an empty vector.
ringSum :: (Ring a, VU.Unbox a) => VU.Vector a -> a
ringSum = VU.foldl' (.+.) zero

-- | Ring sum ('.+.') of a list; 'zero' for an empty list.
-- No 'VU.Unbox' constraint is needed for lists.
ringSumL :: (Ring a) => [a] -> a
ringSumL = foldl' (.+.) zero

-- | Ring product ('.*.') of a vector; 'one' for an empty vector.
ringProduct :: (Ring a, VU.Unbox a) => VU.Vector a -> a
ringProduct = VU.foldl' (.*.) one

-- | Ring product ('.*.') of a list; 'one' for an empty list.
-- No 'VU.Unbox' constraint is needed for lists.
ringProductL :: (Ring a) => [a] -> a
ringProductL = foldl' (.*.) one
-- | Offsets @(di,dj)@ of the inner pair @(i+di, j-dj)@ for generic interior
-- loops closed by @(i,j)@. Both offsets are at least 3. Their sum @d@ starts
-- at 8, and at most 23 sums are generated (so @d <= 30@); the span @j-i@
-- limits the count further.
interiorLoopDistances :: Int -> Int -> VU.Vector (Int,Int)
interiorLoopDistances i j =
  VU.concatMap
    (\d -> VU.map (\d' -> (d', d-d')) $ VU.enumFromN 3 (d-5))
    (VU.enumFromN 8 (min 23 (j-i-13)))

-- | Inner-pair cells @(i+di, j-dj)@ for every offset of 'interiorLoopDistances'.
interiorLoopIndices :: Int -> Int -> VU.Vector (Int,Int)
interiorLoopIndices !i !j = VU.map (\(di,dj) -> (i+di, j-dj)) $ interiorLoopDistances i j
-- | Unpaired counts @(di,dj)@ (5' side, 3' side) of the loop types handled by
-- 'tabbedInteriorLoopBase'; empty when @j-i < 8@.
tabbedInteriorLoopDistances :: Int -> Int -> VU.Vector (Int,Int)
tabbedInteriorLoopDistances i j
  | j-i>=8    = VU.fromList [(0,1),(1,0),(1,1),(1,2),(2,1),(2,2),(2,3),(3,2)]
  | otherwise = VU.empty

-- | Inner-pair cells @(i+di+1, j-dj-1)@ for every entry of
-- 'tabbedInteriorLoopDistances'.
tabbedInteriorLoopIndices :: Int -> Int -> VU.Vector (Int,Int)
tabbedInteriorLoopIndices i j = VU.map (\(di,dj) -> (i+di+1, j-dj-1)) $ tabbedInteriorLoopDistances i j
-- Each limit below is a (start, count) pair for VU.enumFromN, i.e. the
-- candidates are start, start+1, …, start+count-1.

-- | Bulge sizes for the loop closed by @(i,j)@: from 2 up to at most 30.
bulgeLimit :: Int -> Int -> (Int,Int)
bulgeLimit i j = (2, min 29 $ j-i-9)

-- | Long-side sizes of 1xn interior loops closed by @(i,j)@: from 3 up to at most 28.
iloop1xnLimit :: Int -> Int -> (Int,Int)
iloop1xnLimit i j = (3, min 26 $ j-i-10)

-- | Split points @k@ for closing a multibranch loop: @i+1 .. j-2@.
multibranchCloseLimit :: Int -> Int -> (Int,Int)
multibranchCloseLimit i j = (i+1, j-i-2)

-- | Split points @k@ for adding a branch: @i+1 .. j-1@.
multibranchAddKJHelixLimit :: Int -> Int -> (Int,Int)
multibranchAddKJHelixLimit i j = (i+1, j-i-1)

-- | Start positions @k@ of the first branch: @i .. j-1@.
multibranchKJHelixLimit :: Int -> Int -> (Int,Int)
multibranchKJHelixLimit i j = (i, j-i)

-- | Split points @k@ of the external loop: @i+5 .. j-1@.
externalAddLoopLimit :: Int -> Int -> (Int,Int)
externalAddLoopLimit i j = (i+5, j-i-5)