{-# LANGUAGE TypeFamilies #-}
{- |
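This module provides a simple way to train
the transition matrix and the initial probability vector
from patterns of state sequences.
Patterns are built from single states with 'atom'
and combined by sequencing ('append', '<>') and repetition ('replicate').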

You may create a trained model using semigroup combinators like this:

> example :: HMM.DiscreteTrained Char (ShapeStatic.ZeroBased TypeNum.U2) Double
> example =
>    let a = atom FL.i0
>        b = atom FL.i1
>        distr =
>           Distr.DiscreteTrained $ Map.fromList $
>           ('a', ShapeStatic.vector $ 1!:2!:FL.end) :
>           ('b', ShapeStatic.vector $ 4!:3!:FL.end) :
>           ('c', ShapeStatic.vector $ 0!:1!:FL.end) :
>           []
>    in finish (ShapeStatic.ZeroBased Proxy) distr $
>       replicate 5 $ replicate 10 a <> replicate 20 b
-}
module Math.HiddenMarkovModel.Pattern (
   T,
   atom,
   append,
   replicate,
   finish,
   ) where

import qualified Math.HiddenMarkovModel.Distribution as Distr
import qualified Math.HiddenMarkovModel as HMM
import Math.HiddenMarkovModel.Private (Trained(..))
import Math.HiddenMarkovModel.Utility (SquareMatrix, squareConstant)

import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.ShapeStatic as ShapeStatic

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as StorableArray
import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.FixedLength as FL
import Data.FixedLength ((!:))

import qualified Type.Data.Num.Unary.Literal as TypeNum
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Map as Map
import Data.Semigroup (Semigroup, (<>), stimes)

import Prelude hiding (replicate)


-- | A pattern of state sequences.
--
-- Given the state shape, a pattern yields a triple of
--
--  * the state the pattern starts in,
--
--  * a square matrix of transition counts accumulated within the pattern,
--    indexed as @(to, from)@ (see how 'append' records the hop from the end of
--    the first pattern to the start of the second),
--
--  * the state the pattern ends in.
newtype T sh prob =
   Cons (sh -> (Shape.Index sh, SquareMatrix sh prob, Shape.Index sh))

-- | A pattern consisting of a single occurrence of the given state.
--
-- It starts and ends in that state and records no transitions: its
-- transition-count matrix is the constant-zero square matrix for the shape.
atom ::
   (Shape.Indexed sh, Shape.Index sh ~ state, Class.Real prob) =>
   state -> T sh prob
atom state = Cons build
   where
      build sh = (state, squareConstant sh 0, state)


-- | '<>' sequences two patterns with 'append'; 'stimes' repeats a pattern
-- with 'replicate', converting the multiplier to 'Int'.
instance
   (Shape.Indexed sh, Eq sh, Class.Real prob) =>
      Semigroup (T sh prob) where
   x <> y = append x y
   stimes n = replicate (fromIntegral n)


infixl 5 `append`

-- | Sequence two patterns: first the left one, then the right one.
--
-- The transition counts of both parts are added. One further transition is
-- counted, from the final state of the left pattern to the initial state of
-- the right pattern. The result starts where the left pattern starts and ends
-- where the right pattern ends.
append ::
   (Shape.Indexed sh, Eq sh, Class.Real prob) =>
   T sh prob -> T sh prob -> T sh prob
append (Cons first) (Cons second) =
   Cons $ \sh ->
      case first sh of
         (startA, countsA, endA) ->
            case second sh of
               (startB, countsB, endB) ->
                  -- the joining transition, stored at index (to, from)
                  let hop = increment (startB, endA) 1
                  in  (startA, hop (Vector.add countsA countsB), endB)

-- | @replicate k p@ is the pattern @p@ repeated @k@ times in a row.
--
-- The transition counts of @p@ are multiplied by @k@. The @k-1@ transitions
-- from the final state of one copy to the initial state of the next copy are
-- then added at index @(initial, final)@, i.e. @(to, from)@.
--
-- The repetition count must be positive. A pattern always has an initial and
-- a final state, so there is no empty pattern that @k <= 0@ could denote.
-- Such a count would otherwise yield negative transition counts, so it is
-- rejected with an error.
replicate ::
   (Shape.Indexed sh, Class.Real prob) => Int -> T sh prob -> T sh prob
replicate ki (Cons f)
   | ki < 1 =
      error $
         "Math.HiddenMarkovModel.Pattern.replicate: " ++
         "positive repetition count expected, got " ++ show ki
   | otherwise =
      Cons $ \sh ->
         case f sh of
            (si, m, so) ->
               let k = fromIntegral ki
               in  (si, increment (si,so) (k-1) $ Vector.scale k m, so)

-- | Add a value to a single entry of a square matrix, leaving all other
-- entries unchanged.
--
-- Callers in this module pass the destination state as the first index and
-- the source state as the second.
increment ::
   (Shape.Indexed sh, Shape.Index sh ~ state, Class.Real a) =>
   (state, state) -> a -> SquareMatrix sh a -> SquareMatrix sh a
increment (row, col) delta matrix =
   StorableArray.accumulate (+) matrix [((row, col), delta)]


-- | Turn a pattern into training data for a hidden Markov model.
--
-- The initial vector has a 1 at the pattern's first state and 0 everywhere
-- else. The transition matrix is the pattern's accumulated transition counts.
-- The given trained distribution is stored unchanged. The pattern's final
-- state is not used.
finish ::
   (Shape.Indexed sh, Class.Real prob) =>
   sh -> tdistr -> T sh prob -> Trained tdistr sh prob
finish sh tdistr (Cons f) = fromParts (f sh)
   where
      fromParts (start, counts, _end) =
         Trained {
            trainedInitial = StorableArray.fromAssociations sh 0 [(start, 1)],
            trainedTransition = counts,
            trainedDistribution = tdistr
         }


-- | The example from the module header.
--
-- It uses two states, 0 and 1. The pattern is 10 steps in state 0 followed
-- by 20 steps in state 1, and that whole block is repeated 5 times. The map
-- assigns each output symbol a vector with one entry per state (presumably
-- per-state emission weights; see 'Distr.DiscreteTrained').
_example :: HMM.DiscreteTrained Char (ShapeStatic.ZeroBased TypeNum.U2) Double
_example =
   finish (ShapeStatic.ZeroBased Proxy) emissions $
      replicate 5 (replicate 10 stateA <> replicate 20 stateB)
   where
      stateA = atom FL.i0
      stateB = atom FL.i1
      emissions =
         Distr.DiscreteTrained $ Map.fromList
            [ ('a', ShapeStatic.vector $ 1!:2!:FL.end)
            , ('b', ShapeStatic.vector $ 4!:3!:FL.end)
            , ('c', ShapeStatic.vector $ 0!:1!:FL.end)
            ]