{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Math.HiddenMarkovModel.Named (
   T(..),
   Discrete,
   Gaussian,
   fromModelAndNames,
   toCSV,
   fromCSV,
   ) where

import qualified Math.HiddenMarkovModel.Distribution as Distr
import qualified Math.HiddenMarkovModel.Private as HMM
import qualified Math.HiddenMarkovModel.CSV as HMMCSV
import Math.HiddenMarkovModel.Utility (attachOnes, vectorDim)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as StorableArray
import qualified Data.Array.Comfort.Boxed as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Boxed (Array)

import qualified Text.CSV.Lazy.String as CSV
import Text.Printf (printf)

import qualified Control.Monad.Exception.Synchronous as ME
import qualified Control.Monad.Trans.State as MS
import Control.DeepSeq (NFData, rnf)
import Foreign.Storable (Storable)

import qualified Data.Map as Map
import qualified Data.List as List
import Data.Tuple.HT (swap)
import Data.Map (Map)


{- |
A Hidden Markov Model with a name for each state.

Although 'nameFromStateMap' and 'stateFromNameMap' are exported,
you must be careful to keep them consistent with each other
(and with the state shape of 'model') whenever you alter them.
-}
data T distr sh ix prob =
   Cons
      { model :: HMM.T distr sh prob
        -- ^ the underlying model without state names
      , nameFromStateMap :: Array sh String
        -- ^ the name of every state, laid out over the state shape @sh@
      , stateFromNameMap :: Map String ix
        -- ^ inverse of 'nameFromStateMap': the state index for each name
      }
   deriving (Show)

-- | Named model with 'Distr.Discrete' emissions,
-- whose states are indexed by @Shape.Index sh@.
type Discrete symbol sh p =
   T (Distr.Discrete symbol sh p)
     sh (Shape.Index sh) p

-- | Named model with 'Distr.Gaussian' emissions,
-- whose states are indexed by @Shape.Index sh@.
type Gaussian emiSh sh a =
   T (Distr.Gaussian emiSh sh a)
     sh (Shape.Index sh) a


instance
   (NFData distr, NFData sh, NFData ix, NFData prob,
    Shape.C sh, Storable prob) =>
      NFData (T distr sh ix prob) where
   -- Fully evaluate the model, then the name array, then the name map.
   rnf (Cons md names states) =
      rnf md `seq` rnf names `seq` rnf states


{- |
Attach state names to a model.

The names are laid out over the shape of the model's initial vector,
and the reverse lookup map is derived from that layout.
No validation happens here: the caller must supply one unique name per state.
A repeated name turns its entry in 'stateFromNameMap' into an error
when that entry is looked up.
NOTE(review): the effect of a wrong number of names depends on
'Array.fromList' and is not checked in this function.
-}
fromModelAndNames ::
   (Shape.Indexed sh, Shape.Index sh ~ state) =>
   HMM.T distr sh prob -> [String] -> T distr sh state prob
fromModelAndNames md names =
   Cons {
      model = md,
      nameFromStateMap = nameArray,
      stateFromNameMap = inverseMap nameArray
   }
  where
   nameArray = Array.fromList (StorableArray.shape (HMM.initial md)) names

{- |
Invert a state-to-name array into a name-to-state map.

If a name is attached to more than one state,
the map entry for that name becomes an 'error' that reports the offending label.
"Data.Map" is lazy in its values, so the error is raised
only when that entry is actually looked up or forced.
-}
inverseMap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) => Array sh String -> Map String ix
inverseMap =
   Map.fromListWithKey duplicate . map swap . Array.toAssociations
  where
   -- Called by fromListWithKey with (key, new value, old value)
   -- whenever a key repeats.
   duplicate name _ _ =
      error $
         "Math.HiddenMarkovModel.Named.inverseMap: duplicate label " ++
         show name


{- |
Render a named model as CSV text.

The first row holds the state names, followed by the rows from 'HMM.toCells'.
The table is padded with empty strings by 'HMMCSV.padTable' before printing.
-}
toCSV ::
   (Distr.ToCSV distr, Shape.Indexed sh, Class.Real prob, Show prob) =>
   T distr sh ix prob -> String
toCSV hmm =
   let header = Array.toList (nameFromStateMap hmm)
       table = HMMCSV.padTable "" (header : HMM.toCells (model hmm))
   in  CSV.ppCSVTable (snd (CSV.toCSVTable table))

{- |
Parse a named model from CSV text.

The text is split into rows by 'CSV.parseCSV'.
Each row is passed through 'HMMCSV.fixShortRow', and the resulting rows
are consumed by 'parseCSV'.
The function argument builds the state shape from a state count;
it is forwarded to 'HMM.parseCSV'.
-}
fromCSV ::
   (Distr.FromCSV distr, Distr.StateShape distr ~ stateSh,
    Shape.Indexed stateSh, Shape.Index stateSh ~ state,
    Class.Real prob, Read prob) =>
   (Int -> stateSh) ->
   String -> ME.Exceptional String (T distr stateSh state prob)
fromCSV makeShape input =
   MS.evalStateT (parseCSV makeShape) $
   map HMMCSV.fixShortRow $
   CSV.parseCSV input

{- |
Parser for a named model.

It reads the state names from the first row and fails if any name repeats.
It then parses the underlying model with 'HMM.parseCSV', and fails if
the number of names differs from the dimension of the model's initial vector.
-}
parseCSV ::
   (Distr.FromCSV distr, Distr.StateShape distr ~ stateSh,
    Shape.Indexed stateSh, Shape.Index stateSh ~ state,
    Class.Real prob, Read prob) =>
   (Int -> stateSh) -> HMMCSV.CSVParser (T distr stateSh state prob)
parseCSV makeShape = do
   stateNames <- HMMCSV.parseStringList =<< HMMCSV.getRow
   -- Count each name's occurrences and keep those that appear more than once.
   let repeated =
          Map.keys . Map.filter (> (1::Int)) . Map.fromListWith (+) $
          attachOnes stateNames
   HMMCSV.assert (null repeated) $
      "duplicate names: " ++ List.intercalate ", " repeated
   md <- HMM.parseCSV makeShape
   let numNames = length stateNames
       numStates = vectorDim (HMM.initial md)
   HMMCSV.assert (numNames == numStates) $
      printf "got %d state names for %d states" numNames numStates
   return $ fromModelAndNames md stateNames