{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RecordWildCards #-}

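-- | Dynamic-programming folding for MC-Fold-style nucleotide cyclic motifs
-- (NCMs): filling the DP tables ('fold') and enumerating suboptimal
-- structures from them ('backtrack').
--
-- TODO bonus system for matching "()" brackets still broken?!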

module BioInf.MCFoldDP where

import Control.Arrow (first,second)
import Control.Monad
import Control.Monad.ST
import Data.List (find,sort)
import Data.Tuple.Select (sel3)
import qualified Data.Vector.Unboxed as VU

import Biobase.DataSource.MCFold
import Biobase.Constants
import Biobase.RNA
import Biobase.RNA.Hashes
import Biobase.RNA.NucBounds
import Biobase.Structure
import Biobase.Structure.Constraint
import Data.PrimitiveArray
import Data.PrimitiveArray.Ix



-- | Pure entry point for folding: runs 'foldST' inside 'runST' and hands back
-- the filled tables.

fold :: MotifDB -> Primary -> Constraint -> Tables
fold db inp constr = runST (foldST db inp constr)

-- | Fill every DP table in the 'ST' monad. The number of double-NCM tables
-- equals the number of known double NCMs ('knownDoubleNCM').
--
-- Cells are visited with @i@ running from @n@ down to @0@ and, for each @i@,
-- @j@ running from @i@ up to @n@. 'mkTable2' returns each table as a frozen
-- view (obtained with 'unsafeFreezeM') paired with its mutable array. The
-- loop reads through the views while writing through the arrays, which
-- relies on the unsafe freeze sharing storage with the mutable array.
--
-- Calls 'error' when the sequence and the constraint differ in length.

foldST :: MotifDB -> Primary -> Constraint -> ST s Tables
foldST db inp constr = do
  when (VU.length inp /= (VU.length . unConstraint $ constr)) $
    error "foldST: inp / constr length mismatch"
  let n = VU.length inp -1
  (sncm,sncmM)     <- mkTable2 n
  -- one table per known double NCM
  (dncms,dncmMs)   <- fmap unzip . replicateM (VU.length knownDoubleNCM) $ mkTable2 n
  (tmp,tmpM)       <- mkTable2 n
  (mbr,mbrM)       <- mkTable2 n
  (mbr1,mbr1M)     <- mkTable2 n
  (extern,externM) <- mkTable2With 0 n
  forM_ [n,n-1 .. 0] $ \i -> forM_ [i,i+1 .. n] $ \j -> do

    -- single NCM calculation, together with multibranched loops / large
    -- interior loops
    --
    -- TODO check if large interior loops are ok?! Do we even want those, or
    -- should they come solely from NCMs? Have an option for that
    let sncmIJ = fncmSingle db inp constr i j
    let multiIJ = vuminimumP $ fMulti db inp constr i j mbr mbr1
    let interiorIJ = vuminimumP $ fInterior db inp constr i j tmp
    writeM sncmM (i,j) $ minimum [sncmIJ, multiIJ, interiorIJ]

    -- double NCM calculation (only the writes matter, hence forM_)
    forM_ (zip dncmMs $ VU.toList knownDoubleNCM) $ \(dncmM, ((di,dj),_)) -> do
      let k = i+di-1
      let l = j-dj+1
      -- (1) Begin a stem (dNCM follows sNCM)
      let dsIJ = fncmDS db inp constr i j k l sncm
      -- (2) continue a stem (dNCM follows dNCM)
      let ddIJ = vuminimum . VU.map snd $ fncmDD db inp constr i j k l dncms
      writeM dncmM (i,j) $ dsIJ `min` ddIJ

    -- fill helper table (minimum over all double NCM tables) which makes mbr
    -- calculations ~2x faster; with no double NCM tables this is eInf rather
    -- than a crash in 'minimum'
    writeM tmpM (i,j) $ case map (!(i,j)) dncms of
      [] -> eInf
      vs -> minimum vs

    -- fill mbr table
    let upIJ = if (i+1<j) then mbr!(i,j-1) else eInf
    let stemIJ = vuminimumP $ fMStem i j tmp
    let mbrstemIJ = if (i+2<j) then vuminimum . VU.map snd $ fMMbrStem i j mbr tmp else eInf
    writeM mbrM (i,j) $ minimum [upIJ,stemIJ,mbrstemIJ]

    -- fill mbr1 table
    let stem1IJ = tmp!(i,j) -- just a mnemonic
    let up1IJ = if (i+1<j) then mbr1!(i,j-1) else eInf
    writeM mbr1M (i,j) $ stem1IJ `min` up1IJ

  -- fill extern table (only the last column j = n is needed)
  let j=n
  forM_ [n-1,n-2..0] $ \i -> do
    let unpairedExt = extern!(i+1,j)
    let stemExtExt = if (i+2<j) then (vuminimum . VU.map sel3 $ fStemExtExt dncms extern i j) else eInf
    let stemExt = if (i+2<j) then (vuminimum . VU.map sel3 $ fStemExt dncms i j) else eInf
    writeM externM (i,j) $ minimum [unpairedExt, stemExtExt, stemExt]

  return (sncm,dncms,mbr,mbr1,extern)

-- | Backtracking suboptimal results
--
-- Enumerate the structures whose energy lies within @delta@ of the optimum
-- stored in the filled 'Tables'. Each result pairs an energy with a
-- 'D1Secondary' built by 'mkD1S' from the sorted list of base pairs.
--
-- Every local helper (@externbt@, @dncmbt@, @sncmbt@, @mbrbt@, @mbr1bt@)
-- takes the remaining slack @d@ and an interval @(i,j)@. Picking a
-- decomposition with energy @bestE@ at a cell whose stored optimum is @opt@
-- leaves slack @d' = opt - bestE + d@. Candidates with @d' < 0@ are pruned.
-- Helpers return @(remaining slack, base pairs)@.
--
-- The reported energy is @bE - bonus + delta - e@. Here @bE = extern ! (0,n)@,
-- @e@ is the slack left over, and @bonus@ is the multiple of 'bonusScore'
-- closest to @bE@. If @bE >= -0.0001@, the result is a single structure
-- without pairs and energy 0.

backtrack :: MotifDB -> Double -> Primary -> Constraint -> Tables -> [(Double,D1Secondary)]
backtrack db delta inp cns@(Constraint constr) (sncm,dncms,mbr,mbr1,extern) = outp where
  outp
    | bE >= -0.0001 = [(0,f ((n+1),[]))]
    | otherwise = {-filter ((<=0).fst) .-} map (first (bE-bonus+delta-) . second (f . ((n+1),) . sort)) $ externbt delta 0 n
    where f :: (Int,[(Int,Int)]) -> D1Secondary
          f = mkD1S
  -- External (unenclosed) interval (i,j), optimum in extern!(i,j).
  -- NOTE(review): every alternative ends in at least one stem, so an interval
  -- without pairs yields no result. The pair-free structure is only produced
  -- by the bE >= -0.0001 shortcut in outp.
  externbt d i j =
    -- unpaired nucleotide to the left
    [ (e,x)
    | i<j-1
    , let bestE = extern ! (i+1,j)
    , let d' = extern ! (i,j) - bestE + d
    , d'>=0
    , (e,x) <- externbt d' (i+1) j
    ] ++
    -- stem
    [ (e,x)
    | i<j
    , (idx,k,bestE) <- VU.toList $ fStemExt dncms i j
    , let d' = extern!(i,j) - bestE + d
    , d'>=0
    , (e,x) <- dncmbt d' idx i k
    ] ++
    -- two or more stems
    [ (ey,x++y)
    | i<j
    , (idx,k,bestE) <- VU.toList $ fStemExtExt dncms extern i j
    , let d' = extern!(i,j) - bestE + d
    , d'>=0
    , (ex,x) <- dncmbt d' idx i k
    , ex>=0
    , (ey,y) <- externbt ex (k+1) j
    , ey>=0
    ]
  -- Double NCM from table number idx spanning (i,j). Its inner ends (k,l)
  -- follow from the (di,dj) sizes stored at position idx of knownDoubleNCM.
  -- Both alternatives emit the pair (i,j).
  dncmbt d idx i j =
    -- D-D
    [ (e,(i,j):x)
    | let (di,dj) = fst $ knownDoubleNCM VU.! idx
    , let dncm = dncms!!idx
    , let k = i+di-1, let l = j-dj+1
    , (nidx,bestE) <- VU.toList $ fncmDD db inp cns i j k l dncms
    , let d' = dncm!(i,j) - bestE + d
    , d'>=0
    , (e,x) <- dncmbt d' nidx k l
    ] ++
    -- D-S
    [ (e,(i,j):x)
    | let (di,dj) = fst $ knownDoubleNCM VU.! idx
    , let dncm = dncms!!idx
    , let k = i+di-1, let l = j-dj+1
    , let bestE = fncmDS db inp cns i j k l sncm
    , let d' = dncm!(i,j) - bestE + d
    , d'>=0
    , (e,x) <- sncmbt d' k l
    ]
  -- Single NCM closed by (i,j), optimum in sncm!(i,j): a terminal single NCM,
  -- an enclosed multibranched loop, or an enclosed interior loop.
  sncmbt d i j =
    [ (d',[(i,j)])
    | j-i>=3
    , let bestE = fncmSingle db inp cns i j
    , let d' = sncm!(i,j) - bestE + d
    , d'>=0
    ] ++
    -- sNCM enclosing a multibranched loop
    [ (ey,(i,j):x++y)
    | j-i>3
    , (k,bestE) <- VU.toList $ fMulti db inp cns i j mbr mbr1
    , let d' = sncm!(i,j) - bestE + d
    , d'>=0
    , (ex,x) <- mbrbt d' (i+1) k
    , ex>=0
    , (ey,y) <- mbr1bt ex (k+1) (j-1)
    ] ++
    -- sNCM enclosing an interior loop
    [ (e,(i,j):x)
    | idx <- [0 .. VU.length knownDoubleNCM -1]
    , let dncm = dncms!!idx
    , ((k,l),bestE) <- VU.toList $ fInterior db inp cns i j dncm
    , let d' = sncm!(i,j) - bestE + d
    , d'>=0
    , (e,x) <- dncmbt d' idx k l
    ]
    -- TODO enclosing an interior loop
    -- NOTE(review): the branch above already enumerates enclosed interior
    -- loops; this TODO may be stale.
  -- Partial multibranch region (i,j), optimum in mbr!(i,j): at least one stem,
  -- unpaired nucleotides allowed on the right.
  mbrbt d i j =
    -- unpaired to the right
    [ (x,z)
    | i+1<j
    , let bestE = mbr!(i,j-1)
    , let d' = mbr!(i,j) - bestE + d
    , d'>=0
    , (x,z) <- mbrbt d' i (j-1)
    ] ++
    -- a stem at k,j
    [ (x,z)
    | i+1<j
    , idx <- [0..VU.length knownDoubleNCM -1]
    , let dncm = dncms!!idx
    , (k,bestE) <- VU.toList $ fMStem i j dncm
    , let d' = mbr!(i,j) - bestE + d
    , d'>=0
    , (x,z) <- dncmbt d' idx k j
    ] ++
    -- two or more stems
    [ (ey,x++y)
    | i+2<j
    , idx <- [0..VU.length knownDoubleNCM -1]
    , let dncm = dncms!!idx
    , (k,bestE) <- VU.toList $ fMMbrStem i j mbr dncm
    , let d' = mbr!(i,j) - bestE + d
    , d'>=0
    , (ex,x) <- mbrbt d' i k
    , ex>=0
    , (ey,y) <- dncmbt ex idx (k+1) j
    ]
  -- Region (i,j) with exactly one stem starting at i, optimum in mbr1!(i,j):
  -- the stem (i,j) itself, or unpaired nucleotides on the right.
  mbr1bt d i j =
    -- add a stem at i j
    [ (x,z)
    | i+1<j
    , idx <- [0..VU.length knownDoubleNCM -1]
    , let dncm = dncms!!idx
    , let bestE = dncm!(i,j)
    , let d' = mbr1!(i,j) - bestE + d
    , d'>=0
    , (x,z) <- dncmbt d' idx i j
    ] ++
    -- unpaired to the right
    [ (x,z)
    | i+1<j
    , let bestE = mbr1!(i,j-1)
    , let d' = mbr1!(i,j) - bestE + d
    , d'>=0
    , (x,z) <- mbr1bt d' i (j-1)
    ]
  -- last valid sequence index
  n = VU.length inp -1
  -- optimal (bonus-including) energy of the whole sequence
  bE = extern ! (0,n)
  -- bonus = bonusScore * (fromIntegral . VU.length . VU.filter (`VU.elem` bonusCC) . VU.map fst $ constr)
  -- set bonus score from constraints so that if a constraint couldn't be met, we still allow the result
  -- (rounds bE / bonusScore to the nearest integer and scales it back, i.e.
  -- the whole multiple of bonusScore closest to bE)
  bonus = (fromIntegral . round $ bE / bonusScore) * bonusScore



-- * Combining NCMs

-- | Score a single NCM closed by the pair @(i,j)@.
--
-- Returns 'eInf' if @i@ or @j@ is marked @'x'@ in the constraint, or if the
-- motif spans fewer than 4 nucleotides. Otherwise the result is the
-- constraint bonus from 'giveBonus' plus the entry for the hashed subsequence
-- @inp[i..j]@ in the 'sCycles' table of the same length. If no such table
-- exists, that term is 0.

fncmSingle :: MotifDB -> Primary -> Constraint -> Int -> Int -> Double
fncmSingle MotifDB{..} inp cns@(Constraint constr) i j
  | blocked i || blocked j = eInf
  | len < 4                = eInf
  | otherwise              = giveBonus cns i j + cycleScore
  where
    blocked p  = fst (constr `VU.unsafeIndex` p) == 'x'
    len        = j - i + 1
    hashed     = mkHashedPrimary (minExtended,maxExtended) (VU.slice i len inp)
    cycleScore = case find ((len ==) . fst) sCycles of
                   Nothing       -> 0
                   Just (_, tbl) -> tbl ! hashed
{-# INLINE fncmSingle #-}

-- | Score a double NCM spanning @(i,j)@ whose inner ends @(k,l)@ close a
-- single NCM. @di = k-i+1@ and @dj = j-l+1@ are the sizes of the two strands
-- of the double NCM, and @len = l-k+1@ is the size of the enclosed single NCM.
--
-- Returns 'eInf' if @i@ or @j@ is marked @'x'@ in the constraint, or if
-- @k >= l@. When both a hinge table in 'dsConnect' (keyed by
-- @((di,dj),len)@) and a cycle table in 'dCycles' (keyed by @(di,dj)@) exist,
-- the result is the enclosed single NCM plus hinge, cycle and constraint
-- bonus. Otherwise it is only the enclosed single NCM plus the bonus.
--
-- TODO make faster (better lookup system)
-- TODO otherwise case
-- TODO eats another ~10% performance

fncmDS :: MotifDB -> Primary -> Constraint -> Int -> Int -> Int -> Int -> Table2 -> Double
fncmDS MotifDB{..} inp cns@(Constraint constr) i j k l sncm
  | blocked i || blocked j = eInf
  | k >= l                 = eInf
  | otherwise              =
      case (((di,dj),len) `lookup` dsConnect, (di,dj) `lookup` dCycles) of
        -- a tabulated junction: full hinge + cycle scoring
        (Just hinge, Just ncm) ->
          inner + hinge!(inp `VU.unsafeIndex` k, inp `VU.unsafeIndex` l) + ncm!ci + bonus
        -- the enclosed single NCM is larger than anything tabulated
        _ -> inner + bonus
  where
    blocked p = fst (constr `VU.unsafeIndex` p) == 'x'
    inner     = sncm!(k,l)
    di        = k-i+1
    dj        = j-l+1
    len       = l-k+1
    ci        = mkHashedPrimary (minExtended,maxExtended) $ VU.slice i di inp VU.++ VU.slice l dj inp
    bonus     = giveBonus cns i j
{-# INLINE fncmDS #-}

-- | Extend a double NCM at @(i,j)@, with inner ends @(k,l)@, by another
-- double NCM starting at @(k,l)@. Returns one @(table index, energy)@ entry
-- per table in @dncms@, paired positionally with the sizes in
-- 'knownDoubleNCM'. An entry is @(-1, eInf)@ when @k+2 >= l@ or when no
-- 'ddConnect' hinge / 'dCycles' table exists for the combination. The vector
-- is empty if @i@ or @j@ is marked @'x'@ in the constraint.
--
-- TODO this one could profit from performance improvements. But check first vs. multibranch timings
-- TODO remove otherwise case
-- TODO improve performance, eats ~66% of total time
-- TODO improve: return empty vector on error, write special minimum function that has eInf on empty

fncmDD :: MotifDB -> Primary -> Constraint -> Int -> Int -> Int -> Int -> [Table2] -> VU.Vector (Int,Double)
fncmDD MotifDB{..} inp cns@(Constraint constr) i j k l dncms
  | blocked i || blocked j = VU.empty
  | otherwise = VU.fromList
      [ extend idx innerDims dncm
      | (idx, (innerDims,_), dncm) <- zip3 [0..] (VU.toList knownDoubleNCM) dncms
      ]
  where
    blocked p = fst (constr `VU.unsafeIndex` p) == 'x'
    bonus     = giveBonus cns i j
    -- strand sizes of the outer double NCM; independent of the inner table
    di        = k-i+1
    dj        = j-l+1
    ci        = mkHashedPrimary (minExtended,maxExtended) $ VU.slice i di inp VU.++ VU.slice l dj inp
    noMatch   = (-1,eInf)
    extend idx innerDims dncm
      | k+2 >= l = noMatch
      | Just hinge <- ((di,dj),innerDims) `lookup` ddConnect
      , Just ncm   <- (di,dj) `lookup` dCycles
      = (idx, dncm!(k,l) + hinge!(inp `VU.unsafeIndex` k, inp `VU.unsafeIndex` l) + ncm!ci + bonus)
      | otherwise = noMatch
{-# INLINE fncmDD #-}

-- | Candidates for placing one stem in an "external" interval. For every
-- double-NCM table (numbered by its position in @dncms@) and every right end
-- @k@ in @[i+1 .. j]@, this yields @(table number, k, dncm!(i,k))@, with
-- tables in the outer order and @k@ ascending.
--
-- TODO make efficient

fStemExt :: [Table2] -> Int -> Int -> VU.Vector (Int,Int,Double) -- (dncms!!,k,value)
fStemExt dncms i j = VU.fromList (concat (zipWith stemsOf [0..] dncms))
  where
    stemsOf d dncm = map (\k -> (d, k, dncm!(i,k))) [i+1 .. j]
{-# INLINE fStemExt #-}

-- | Candidates for a stem @(i,k)@ followed by an arbitrary external region
-- @(k+1,j)@. For every double-NCM table (numbered by its position in @dncms@)
-- and every @k@ in @[i+1 .. j-1]@, this yields
-- @(table number, k, dncm!(i,k) + extern!(k+1,j))@.

fStemExtExt :: [Table2] -> Table2 -> Int -> Int -> VU.Vector (Int,Int,Double)
fStemExtExt dncms extern i j = VU.fromList (concat (zipWith splitsOf [0..] dncms))
  where
    splitsOf d dncm = map (\k -> (d, k, dncm!(i,k) + extern!(k+1,j))) [i+1 .. j-1]
{-# INLINE fStemExtExt #-}

-- | Candidates for closing a multibranched loop with the pair @(i,j)@. For
-- each split point @k@ in @[i+1 .. j-1]@, this yields
-- @(k, mbr!(i+1,k) + mbr1!(k+1,j-1) + bonus)@, where @bonus@ comes from
-- 'giveBonus'. The vector is empty if @i@ or @j@ is marked @'x'@.
--
-- TODO close with singleNCM

fMulti :: MotifDB -> Primary -> Constraint -> Int -> Int -> Table2 -> Table2 -> VU.Vector (Int,Double)
fMulti _db _inp cns@(Constraint constr) i j mbr mbr1
  | blocked i || blocked j = VU.empty
  | otherwise              = VU.map split (VU.enumFromTo (i+1) (j-1))
  where
    blocked p = fst (constr `VU.unsafeIndex` p) == 'x'
    bonus     = giveBonus cns i j
    split k   = (k, mbr!(i+1,k) + mbr1!(k+1,j-1) + bonus)
{-# INLINE fMulti #-}

-- | Candidates for appending a stem @(k+1,j)@ to a partial multibranched
-- region @(i,k)@, for every @k@ in @[i+1 .. j-1]@. The stem part takes a
-- single double-NCM table, not the list of all of them.

fMMbrStem :: Int -> Int -> Table2 -> Table2 -> VU.Vector (Int,Double)
fMMbrStem i j mbr dncm = VU.map attach (VU.enumFromTo (i+1) (j-1))
  where
    attach k = (k, mbr!(i,k) + dncm!(k+1,j))
{-# INLINE fMMbrStem #-}

-- | Candidates for the first stem @(k,j)@ of a multibranched region starting
-- at @i@, for every @k@ in @[i .. j-1]@.

fMStem :: Int -> Int -> Table2 -> VU.Vector (Int,Double)
fMStem i j dncm = VU.map (\k -> (k, dncm!(k,j))) (VU.enumFromTo i (j-1))
{-# INLINE fMStem #-}

-- | Candidates for an interior loop closed by @(i,j)@ around an inner stem
-- @(k,l)@. @k@ runs over @[i+1 .. i+10]@ and, for each @k@, @l@ runs from
-- @j-1@ down to @j-10@. Pairs with @k >= l@ are dropped, as is the directly
-- stacked pair @(i+1,j-1)@. Each entry is @((k,l), dncm!(k,l) + bonus)@. The
-- vector is empty if @i@ or @j@ is marked @'x'@.
--
-- TODO this could profit from a log-based scoring function

fInterior :: MotifDB -> Primary -> Constraint -> Int -> Int -> Table2 -> VU.Vector ((Int,Int),Double)
fInterior MotifDB{..} inp cns@(Constraint constr) i j dncm
  | blocked i || blocked j = VU.empty
  | otherwise              = VU.fromList
      [ ((k,l), dncm!(k,l) + bonus)
      | k <- [i+1 .. i+10]
      , l <- [j-1, j-2 .. j-10]
      , k < l
      , k /= i+1 || l /= j-1   -- not a plain stacking
      ]
  where
    blocked p = fst (constr `VU.unsafeIndex` p) == 'x'
    bonus     = giveBonus cns i j
{-# INLINE fInterior #-}



-- * Helper functions

-- | A two-dim table over @(0,0)@ .. @(size,size)@ whose default value is
-- 'eInf'; shorthand for 'mkTable2With' 'eInf'.

mkTable2 size = mkTable2With eInf size

-- | Create a 2-dim table with bounds @(0,0)@ .. @(n,n)@ and default value
-- @v@ (no explicit associations are supplied). Returns the frozen view
-- (from 'unsafeFreezeM') together with the mutable array it was frozen from.

mkTable2With v n = fromAssocsM (0,0) (n,n) v [] >>= \marr ->
  liftM (\frozen -> (frozen, marr)) (unsafeFreezeM marr)

-- | Minimum over the second components of an unboxed vector of pairs,
-- seeded with 'eInf', so an empty vector yields 'eInf' and the result is
-- never above 'eInf'. Vectors in this algorithm /always/ carry positional
-- information in the first component.
--
-- TODO the next version of the vector library should handle "VU.map snd"
-- better

vuminimumP xs = VU.foldl' step eInf xs
  where
    step acc (_, x) = min acc x
{-# INLINE vuminimumP #-}

-- | Minimum of an unboxed vector, seeded with 'eInf' (so empty gives 'eInf').

vuminimum xs = VU.foldl' min eInf xs
{-# INLINE vuminimum #-}



-- * types and constants

-- | Two-dimensional DP table indexed by sequence positions @(i,j)@.
type Table2 = PrimArray (Int, Int) Double

-- | Four-dimensional table (not referenced in this part of the module).
type Table4 = PrimArray (Int, Int, Int, Int) Double

-- | All tables produced by folding, in the order
-- (single NCM, one table per known double NCM, multibranch,
-- multibranch-with-one-stem, external).
type Tables = (Table2, [Table2], Table2, Table2, Table2)

-- | Score added per fulfilled constraint (see 'giveBonus'): -100000, i.e.
-- @(-10)^5@.
bonusScore :: Double
bonusScore = negate 100000

-- | Give a certain 'bonusScore' for the constraints that have been fulfilled.
--
-- Looks at the constraint entries @(c,k)@ at @i@ and @(d,l)@ at @j@ (the
-- 'Int' is presumably a recorded partner index; confirm against the
-- 'Constraint' producer). Guards are tried in order, first match wins:
--
--   * @'x'@ at either end: 'eInf' (forbidden).
--   * @'('@ at @i@ and @')'@ at @j@, each naming the other: @2 * bonusScore@.
--   * any other @'('@ at @i@ / @')'@ at @j@, or @')'@ at @i@ / @'('@ at @j@: 'eInf'.
--   * @'>'@ at @i@: 'bonusScore' if @k > i@, otherwise 'eInf'.
--   * @'<'@ at @j@: 'bonusScore' if @0 <= l < j@, otherwise 'eInf'.
--   * both characters in 'bonusCC': @2 * bonusScore@; one of them: 'bonusScore'.
--   * otherwise 0.
--
-- TODO should we be more lenient with constraints that would increase the
-- total energy?

giveBonus :: Constraint -> Int -> Int -> Double
giveBonus (Constraint constr) i j
  | any (=='x') [c,d] = eInf
  -- i and j are constrained to pair with each other
  | c == '(' && d == ')'
  , k==j && l==i = bonusScore * 2
  | c == '(' || d == ')' = eInf
  | c == ')' || d == '(' = eInf
  -- NOTE(review): a '>' at i decides the result without consulting d
  | c == '>' && k>i = bonusScore
  | c == '>' = eInf
  | d == '<' && l<j && l>=0 = bonusScore
  | d == '<' = eInf
  | all (`VU.elem` bonusCC) [c,d] = bonusScore * 2
  | any (`VU.elem` bonusCC) [c,d] = bonusScore
  | otherwise = 0
  where
    (c,k) = constr `VU.unsafeIndex` i
    (d,l) = constr `VU.unsafeIndex` j