--
-- Stores Infernal covariance model (CM) data. Built for ease of use rather than
-- speed, but should work reasonably well.
--

-- TODO Data.Vector?
--
-- TODO use generics?
--
-- TODO functions to change to probabilities!
--
-- TODO cmCanonize function!
--
-- TODO put functions into their own modules, a bit of cleanup
--
-- TODO add functions to insert a new node between two already existing nodes; think about how to handle BIF
--
-- TODO add ability to remove node; think about how to handle BIF
--
-- NOTE maybe BIF should not be insertable/removable right now?

module Biobase.Infernal.CM where


import Data.Array.IArray
import Data.List (genericLength)

import Biobase.RNA hiding (nucE) -- "E type" nucleotides do not happen in CMs!



-- * Data types for Covariance Models

-- {{{ Data types

-- | A complete covariance model. Each node and each state can be tagged with
-- additional data. Typically, say after parsing, the tag will be ().
--
-- Whether the numeric values (transitions, emissions, local begins/ends and
-- the null model) are scores or probabilities is recorded in 'cmType'; use
-- 'cmScore2Prob' and 'cmProb2Score' to switch between the two.

data CM n s = CM
  { nodes      :: Array Int (Node n)   -- ^ all nodes, indexed by node id (cf. 'nid')
  , states     :: Array Int (State s)  -- ^ all states; nodes refer to them by the ids in 'nstates'
  , header     :: [(String,String)] -- ^ keeps the list of header entries sorted!
  , localBegin :: Array Int Double     -- ^ local begin value per state id, in the unit given by 'cmType'
  , localEnd   :: Array Int Double     -- ^ local end value per state id, in the unit given by 'cmType'
  , cmType     :: CMType               -- ^ whether the values are scores or probabilities
  , nullModel  :: Array Nucleotide Double -- ^ background model; in score mode relative to a uniform 0.25
  } deriving (Show)

-- | Describes one node of the model. A node groups several states (see
-- 'nstates') and is linked to its parent and child nodes by id.

data Node n = Node
  { nid :: Int          -- ^ node id; also the node's index in 'nodes'
  , ntype :: NodeType   -- ^ kind of node
  , nparents :: [Int] -- ^ ids of the parent nodes. TODO can there be more than one?
  , nchildren :: [Int]  -- ^ ids of the child nodes
  , nstates :: [Int]    -- ^ ids of the states of this node (indices into 'states')
  , ntag :: n           -- ^ user-defined tag, typically () after parsing
  } deriving (Show)

-- | One state of the model, together with its outgoing transitions and its
-- emissions.

data State s = State
  { sid :: Int                -- ^ state id; also the key used in 'localBegin' / 'localEnd'
  , stype :: StateType        -- ^ kind of state
  , snode :: Int              -- ^ presumably the id of the node this state belongs to (TODO confirm)
  , sparents :: [Int]         -- ^ ids of the parent states
  , schildren :: [Transition] -- ^ outgoing transitions
  , semission :: [Emission]   -- ^ emissions, as scores or probabilities depending on 'cmType'
  , stag :: s                 -- ^ user-defined tag, typically () after parsing
  } deriving (Show)

-- | Whether the numbers stored in a 'CM' are base-2 log-odds scores
-- ('CMScore') or probabilities ('CMProb'). CMType is important if we want to
-- set localBegin / localEnd, because those values have to be in the same unit!

data CMType = CMProb | CMScore
  deriving (Show,Eq)

-- | An emission of a state: either one nucleotide or a pair of nucleotides.
-- 'escore' holds the score or the probability of the emission, depending on
-- the 'CMType' of the model.

data Emission
  = EmitS {eNuc :: Nucleotide, escore :: Double}                       -- ^ single nucleotide
  | EmitP { eNucL :: Nucleotide, eNucR :: Nucleotide, escore :: Double} -- ^ nucleotide pair ('eNucL', 'eNucR')
  deriving (Show)

-- | An outgoing edge of a state, pointing at the id of the target. Branches
-- are transitions without an attached score/probability, because both
-- branches are always taken.

data Transition
  = Branch {tchild :: Int}                       -- ^ branch edge, no score
  | Transition {tchild :: Int, tscore :: Double} -- ^ scored edge; score or probability depending on 'cmType'
  deriving (Show)

-- | The different node types (Infernal's node names). The main state type of
-- each node type is listed in 'nodeMainStateAssocs'.

data NodeType = MATP | MATL | MATR | BIF | ROOT | BEGL | BEGR | END
  deriving (Read,Show,Eq,Ord,Enum,Bounded)

-- | The different state types (Infernal's state names).

data StateType = MP | IL | IR | D | ML | MR | B | S | E
  deriving (Read,Show,Eq,Ord,Enum,Bounded)

-- }}}


-- * make a local model out of a global one

-- | Turn a global model into a local one. First the local end probability
-- @pend@ is applied ('cmMakeLocalEnd'), then the local begin probability
-- @pbegin@ ('cmMakeLocalBegin').

cmMakeLocal :: Double -> Double -> CM n s -> CM n s
cmMakeLocal pbegin pend = cmMakeLocalBegin pbegin . cmMakeLocalEnd pend



-- | Configure local begins.
--
-- * State 0 (the root state) is disabled (probability 0).
-- * The first state of node 1 gets @1 - pbegin@.
-- * @pbegin@ is spread uniformly over the main states of all nodes accepted
--   by 'localBeginPossible'.
--
-- Values are stored in the unit given by 'cmType': plain probabilities for a
-- 'CMProb' model, and scores against a background of 1.0 for a 'CMScore'
-- model. This keeps 'localBegin' consistent with what 'cmScore2Prob' and
-- 'cmProb2Score' expect.

cmMakeLocalBegin :: Double -> CM n s -> CM n s
cmMakeLocalBegin pbegin cm = cm{localBegin = localBegin cm // changes} where
  changes = rootS : start : intern
  rootS = (0, inModelUnit 0) -- root disabled!
  start = (head . nstates $ nodes cm ! 1, inModelUnit (1 - pbegin)) -- the first state after "root 0"
  intern = map (\k -> (sid $ nodeMainState cm k, inModelUnit (pbegin / l))) nds
  nds = filter (localBeginPossible cm) . elems $ nodes cm
  l = genericLength nds
  -- local begin values must live in the same unit (prob or score) as the model
  inModelUnit p
    | cmType cm == CMProb = p
    | otherwise           = prob2Score p 1.0



-- TODO have to change the transition score, too!

-- | Configure local ends. @pend@ is spread uniformly over the main states of
-- all nodes accepted by 'localEndPossible'.
--
-- Values are stored in the unit given by 'cmType': plain probabilities for a
-- 'CMProb' model, and scores against a background of 1.0 for a 'CMScore'
-- model.

cmMakeLocalEnd :: Double -> CM n s -> CM n s
cmMakeLocalEnd pend cm = cm{localEnd = localEnd cm // changes} where
  changes = map (\k -> (sid $ nodeMainState cm k, inModelUnit (pend / l))) nds
  -- local END targets: this must use the end predicate, not the begin one
  nds = filter (localEndPossible cm) . elems $ nodes cm
  l = genericLength nds
  -- local end values must live in the same unit (prob or score) as the model
  inModelUnit p
    | cmType cm == CMProb = p
    | otherwise           = prob2Score p 1.0



-- * Transform between score and probability mode

-- | Convert a CM from score mode into probability mode. A model that is
-- already in probability mode is returned unchanged.
--
-- The null model is converted first; its scores are relative to a uniform
-- background of 0.25. The resulting model, which now holds background
-- probabilities, is then used to convert the state scores. Local begin and
-- end scores are converted as well.

cmScore2Prob :: CM n s -> CM n s
cmScore2Prob cm'
  | cmType cm' == CMProb = cm'
  | otherwise = withProbNull
      { states     = statesScore2Prob withProbNull (states withProbNull)
      , localBegin = localBeginScore2Prob (localBegin withProbNull)
      , localEnd   = localEndScore2Prob (localEnd withProbNull)
      , cmType     = CMProb
      }
  where
    probNull     = amap (`score2Prob` 0.25) (nullModel cm')
    withProbNull = cm' { nullModel = probNull }


-- | Convert a CM from probability mode into score mode. A model that is
-- already in score mode is returned unchanged.
--
-- Emission probabilities must be scored against background PROBABILITIES,
-- so the states are converted using the original, probability-mode null
-- model. Only the stored null model itself is turned into scores, relative to
-- a uniform background of 0.25. This mirrors 'cmScore2Prob', which uses the
-- already-converted probability null model, so the two conversions are
-- inverse to each other.

cmProb2Score :: CM n s -> CM n s
cmProb2Score cm'
  | cmType cm' == CMScore = cm'
  | otherwise = cm'
      { states     = statesProb2Score cm' (states cm') -- cm' still holds background probabilities
      , localBegin = localBeginProb2Score (localBegin cm')
      , localEnd   = localEndProb2Score (localEnd cm')
      , cmType     = CMScore
      , nullModel  = amap (`prob2Score` 0.25) (nullModel cm')
      }



-- | Normalize all probabilities of a CM in probability mode. Calling this on
-- a score-mode CM is an error.
--
-- NOTE: normalization is not implemented yet; a probability-mode CM is
-- currently returned unchanged.

cmNormalizeProbabilities :: CM n s -> CM n s
cmNormalizeProbabilities cm = case cmType cm of
  CMScore -> error "cannot normalize score-type CM"
  CMProb  -> cm -- TODO have to map normalization over all scores!



-- {{{ CM score/prob conversion helpers

-- | Turn the transition and emission scores of every state into
-- probabilities.
--
-- * Transitions use a background of 1.0.
-- * Single emissions use the null-model entry of their nucleotide.
-- * Pair emissions use the product of both entries.
-- * Branches carry no score and are copied unchanged.
--
-- The CM is only consulted for its 'nullModel', whose entries are used
-- directly as background probabilities.
statesScore2Prob :: CM n s -> Array Int (State s) -> Array Int (State s)
statesScore2Prob cm = amap toProbs
  where
    bg = nullModel cm
    toProbs st = st
      { schildren = map probT (schildren st)
      , semission = map probE (semission st)
      }
    probT (Transition to v) = Transition to (score2Prob v 1.0)
    probT branch            = branch
    probE (EmitS nuc v)     = EmitS nuc (score2Prob v (bg ! nuc))
    probE (EmitP nl nr v)   = EmitP nl nr (score2Prob v ((bg ! nl) * (bg ! nr)))



-- | Convert local begin scores into probabilities (background 1.0).
localBeginScore2Prob :: Array Int Double -> Array Int Double
localBeginScore2Prob = amap (`score2Prob` 1.0)



-- | Convert local end scores into probabilities (background 1.0).
localEndScore2Prob :: Array Int Double -> Array Int Double
localEndScore2Prob = amap (\score -> score2Prob score 1.0)



-- | Turn the transition and emission probabilities of every state into
-- scores.
--
-- * Transitions are scored against a background of 1.0.
-- * Single emissions are scored against the null-model entry of their
--   nucleotide.
-- * Pair emissions are scored against the product of both entries.
-- * Branches are copied unchanged.
--
-- The CM is only consulted for its 'nullModel', whose entries are used
-- directly as background probabilities.
statesProb2Score :: CM n s -> Array Int (State s) -> Array Int (State s)
statesProb2Score cm = amap toScores
  where
    bgOf nuc = nullModel cm ! nuc
    toScores st = st
      { schildren = [ scoreT t | t <- schildren st ]
      , semission = [ scoreE e | e <- semission st ]
      }
    scoreT (Transition to p) = Transition to (prob2Score p 1.0)
    scoreT branch            = branch
    scoreE (EmitS nuc p)     = EmitS nuc (prob2Score p (bgOf nuc))
    scoreE (EmitP l r p)     = EmitP l r (prob2Score p (bgOf l * bgOf r))



-- | Convert local begin probabilities into scores (background 1.0).
localBeginProb2Score :: Array Int Double -> Array Int Double
localBeginProb2Score = amap (`prob2Score` 1.0)



-- | Convert local end probabilities into scores (background 1.0).
--
-- Uses 'prob2Score', just like 'localBeginProb2Score'. The previous version
-- called 'score2Prob' here, which converted in the wrong direction.
localEndProb2Score :: Array Int Double -> Array Int Double
localEndProb2Score = amap (\p -> prob2Score p 1.0)

-- }}}




-- * Helper Functions

-- {{{ helper functions

-- | Extract the main state of a node (eg the MP state for a MATP node), as
-- given by 'nodeMainStateAssocs'. If the node has no state of that type, this
-- fails with an error naming the node id and the missing state type.

-- TODO shouldn't this just be "head $ nstates n"?

nodeMainState :: CM n s -> Node n -> State s
nodeMainState cm n =
  case filter ((== mainType) . stype) ownStates of
    (st:_) -> st
    []     -> error $ "nodeMainState: node " ++ show (nid n)
                   ++ " has no state of type " ++ show mainType
  where
    -- every NodeType has an entry in nodeMainStateAssocs
    mainType = case ntype n `lookup` nodeMainStateAssocs of
      Just st -> st
      Nothing -> error $ "nodeMainState: no main state type for " ++ show (ntype n)
    ownStates = map (states cm !) (nstates n)


-- | Checks for each node if it can be the target of a local begin.
--
-- Only MATP, MATL, MATR and BIF nodes qualify. Nodes whose parent is the
-- root node 0 are excluded, because 'cmMakeLocalBegin' handles the node after
-- root (node 1) separately. The CM argument is currently unused.

localBeginPossible :: CM n s -> Node n -> Bool
localBeginPossible _ n =
  ntype n `elem` [MATP, MATL, MATR, BIF]
  && 0 `notElem` nparents n -- nodes reachable from "root" (that is: node 1) are handled specially



-- | Checks for each node if it can lead to a local end.
--
-- Only MATP, MATL, MATR, BEGL and BEGR nodes qualify, and only if the node
-- directly following them (id @nid n + 1@) is not an END node.
--
-- NOTE(review): node @nid n + 1@ is looked up without a bounds check, so
-- this assumes every qualifying node has a successor in 'nodes'.

localEndPossible :: CM n s -> Node n -> Bool
localEndPossible cm n =
  ntype n `elem` [MATP, MATL, MATR, BEGL, BEGR]
  && ntype (nodes cm ! (nid n + 1)) /= END



-- | Transform a base-2 log-odds score @x@ into a probability, given the
-- background (null model) probability @bg@ the score is relative to:
-- @score2Prob x bg = bg * 2^x@. A score of minus infinity maps to 0.

-- TODO quickcheck!

score2Prob :: (Eq a, Floating a) => a -> a -> a
score2Prob x bg
  | x == (-1/0) = 0
  | otherwise   = exp (x * log 2) * bg



-- | Back into scores; the inverse of 'score2Prob':
-- @prob2Score p bg = logBase 2 (p / bg)@. A probability of 0 maps to minus
-- infinity.

prob2Score :: (Eq a, Floating a) => a -> a -> a
prob2Score x bg
  | x == 0    = (-1/0)
  | otherwise = log (x / bg) / log 2



-- | Transform a single state, setting probabilities instead of scores.
-- Requires CM knowledge for the background model.
--
-- NOTE: not implemented; every call fails with the error "implement me".
-- TODO actually use the background model

stateScore2Prob :: CM n s -> State s -> State s
stateScore2Prob _ _ = error "implement me"



-- | Transform a single state, setting scores instead of probabilities.
--
-- NOTE: not implemented; every call fails with the error "implement me".

stateProb2Score :: CM n s -> State s -> State s
stateProb2Score _ _ = error "implement me"

-- | The target ids of a list of transitions. Branches and scored transitions
-- both carry their target in 'tchild'.
transitionTargets :: [Transition] -> [Int]
transitionTargets = map tchild


-- | The main state type of every node type (e.g. 'MP' for 'MATP'), listed in
-- the declaration order of 'NodeType'. Because 'mainOf' covers every
-- constructor, adding a node type without a main state triggers an
-- incomplete-pattern warning.
nodeMainStateAssocs :: [(NodeType,StateType)]
nodeMainStateAssocs = [ (nt, mainOf nt) | nt <- [minBound .. maxBound] ]
  where
    mainOf MATP = MP
    mainOf MATL = ML
    mainOf MATR = MR
    mainOf BIF  = B
    mainOf ROOT = S
    mainOf BEGL = S
    mainOf BEGR = S
    mainOf END  = E

-- }}}



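-- NOTE conversion formulas (checked, as implemented in score2Prob / prob2Score):
-- score -> prob : exp (x * log 2) = 2^x, times the background probability
--                 (an earlier note claimed "exp (x / log 2)"; that is wrong)
-- prob -> score : (log x) / log 2 = logBase 2 x, after dividing x by the background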
-- TODO mkGlobal (needed?)
-- TODO use logsum : log (exp x + exp y) = x + log (1 + exp (y-x)), where x>y