{-# LANGUAGE RebindableSyntax #-}
-- | Helpers around Gaussian hidden Markov models for labelled feature signals:
-- label dictionaries, transition checks, signal preparation,
-- supervised training and convergence control for iterated training.
module HiddenMarkovModel where

import qualified LabelChain
import qualified Label
import qualified Named

import qualified Math.HiddenMarkovModel.Named as HMMNamed
import qualified Math.HiddenMarkovModel as HMM

import qualified Numeric.Container as NC
import qualified Data.Packed.Matrix as Matrix
import qualified Data.Packed.Vector as Vector
import Data.Packed.Vector (Vector)

import qualified Data.StorableVector.Lazy as SVL
import Foreign.Storable (Storable)

import Text.Printf (printf, )

import qualified Options.Applicative as OP

import qualified System.Path.PartClass as PathClass
import qualified System.Path as Path

import qualified Control.Monad.Exception.Synchronous as ME
import qualified Control.Parallel.Strategies as Par

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Monoid.HT as Mn
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import qualified Data.Map as Map; import Data.Map (Map)
import qualified Data.Set as Set; import Data.Set (Set)
import Data.Traversable (Traversable)
import Data.Foldable (foldMap)
import Data.Monoid ((<>))
import Data.NonEmpty ((!:))
import Data.Tuple.HT (swap)

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Sorted list of state labels.
--
-- NOTE(review): this list contains 'Label.growling', whereas
-- 'admissibleTransitions' refers to 'Label.growlingClickBegin' and
-- 'Label.growlingClickEnd' instead. Confirm that the difference is intended.
allStates :: [String]
allStates =
   List.sort
      [Label.clickBegin, Label.clickEnd,
       Label.chirpingMain, Label.chirpingPause,
       Label.growling, Label.pause]

-- | For every source label the list of labels that may follow it.
admissibleTransitions :: [(String, [String])]
admissibleTransitions =
   (Label.pause,
      [Label.pause, Label.chirpingMain,
       Label.clickBegin, Label.growlingClickBegin]) :
   (Label.clickBegin,
      [Label.clickBegin, Label.clickEnd]) :
   (Label.clickEnd,
      [Label.clickBegin, Label.clickEnd, Label.chirpingMain,
       Label.growlingClickBegin, Label.pause]) :
   (Label.chirpingMain,
      [Label.chirpingMain, Label.chirpingPause]) :
   (Label.chirpingPause,
      [Label.chirpingMain, Label.chirpingPause, Label.clickBegin,
       Label.growlingClickBegin, Label.pause]) :
   (Label.growlingClickBegin,
      [Label.growlingClickBegin, Label.growlingClickEnd]) :
   (Label.growlingClickEnd,
      [Label.growlingClickBegin, Label.growlingClickEnd,
       Label.chirpingMain, Label.clickBegin, Label.pause]) :
   []

-- | 'admissibleTransitions' flattened to a set of @(from, to)@ pairs.
admissibleTransitionSet :: Set (String, String)
admissibleTransitionSet =
   foldMap
      (\(from, tos) -> Set.fromList $ map ((,) from) tos)
      admissibleTransitions

-- | Collect all label pairs that have a positive entry in the trained
-- transition matrix but are not contained in the @admissible@ set.
--
-- Every matrix entry is tagged with its row and column index.
-- The label of the column becomes the first component of the pair and
-- the label of the row the second one, i.e. columns are treated as the
-- @from@ side and rows as the @to@ side of the @admissible@ pairs.
-- State indices missing from @dict@ raise an 'error' via 'checkedLookup'.
forbiddenTransitions ::
   Set (String, String) ->
   Map HMM.State String ->
   HMM.GaussianTrained Double ->
   Set (String, String)
forbiddenTransitions admissible dict =
   flip Set.difference admissible .
   foldMap
      (foldMap
         (\(row, (col, x)) ->
            Mn.when (x > 0) $
            Set.singleton
               (checkedLookup dict (HMM.state col),
                checkedLookup dict (HMM.state row)))) .
   zipWith (\k -> map ((,) k) . zip [0..]) [0..] .
   Matrix.toLists .
   HMM.trainedTransition

-- | Swap keys and values of a state dictionary.
-- A label that is assigned to more than one state yields
-- @error "duplicate label"@ once the affected entry is evaluated.
inverseMap :: Map HMM.State String -> Map String HMM.State
inverseMap =
   Map.fromListWith (error "duplicate label") . map swap . Map.toList

-- | Map lookup that calls 'error' with the shown key if the key is missing.
checkedLookup :: (Ord k, Show k) => Map k a -> k -> a
checkedLookup m k =
   Map.findWithDefault
      (error $ "checkedLookup: unknown key " ++ show k) k m

-- | Number the labels consecutively starting at state 0 and return
-- both lookup directions: @(label -> state, state -> label)@.
mapsFromLabels :: [String] -> (Map String HMM.State, Map HMM.State String)
mapsFromLabels ss =
   let m = Map.fromList $ zip (map HMM.state [0..]) ss
   in  (inverseMap m, m)


-- | Turn a named signal into a named non-empty signal.
-- An empty signal is reported as an exception whose message contains
-- the file path and the signal name.
checkNonEmpty ::
   (PathClass.AbsRel ar) =>
   Path.File ar -> Named.Signal -> ME.Exceptional String Named.NonEmptySignal
checkNonEmpty path (Named.Cons name sig) =
   case SVL.viewL sig of
      Nothing ->
         ME.throw $
         printf "%s: %s: empty feature signal" (Path.toString path) name
      Just (x,xs) -> return $ Named.Cons name $ x !: xs

-- | Prepend the head element again to obtain a plain lazy storable vector.
flattenStorableVectorLazy ::
   (Storable a) => NonEmpty.T SVL.Vector a -> SVL.Vector a
flattenStorableVectorLazy (NonEmpty.Cons x xs) = SVL.cons x xs

-- | Convert a list of feature signals into a non-empty sequence of
-- feature vectors: the i-th vector holds the i-th sample of every signal,
-- converted with 'realToFrac'.
--
-- The first vector is built from the heads of all signals, the remaining
-- ones from the transposed tails. Since 'List.transpose' skips exhausted
-- lists, signals of different lengths lead to shorter vectors at the end.
prepare :: [Named.NonEmptySignal] -> NonEmpty.T [] (Vector Double)
prepare nxs =
   let xs = map Named.body nxs
       vecFromList = NC.cmap realToFrac . Vector.fromList
   in  (vecFromList $ map NonEmpty.head xs)
       !:
       (map vecFromList $ List.transpose $
        map (SVL.unpack . NonEmpty.tail) xs)

-- | Reveal the state sequence for the given signals using @model@.
label :: HMM.Gaussian Double -> [Named.NonEmptySignal] -> [HMM.State]
label model = NonEmpty.flatten . HMM.reveal model . prepare

-- | Reveal the state sequence with the named model, segment it with
-- 'LabelChain.segment' and replace states by their names.
analyze ::
   HMMNamed.Gaussian Double -> [Named.NonEmptySignal] -> LabelChain.T Int String
analyze model =
   fmap (checkedLookup $ HMMNamed.nameFromStateMap model) .
   LabelChain.segment .
   label (HMMNamed.model model)


-- | Translate labels to states using @dict@ and flatten the label chain
-- with 'LabelChain.flattenLabels'.
-- Labels missing from @dict@ raise an 'error' via 'checkedLookup'.
flattenIntervals :: Map String HMM.State -> LabelChain.T Int String -> [HMM.State]
flattenIntervals dict =
   LabelChain.flattenLabels . fmap (checkedLookup dict)

-- | Supervised training from one signal set and its label chain.
-- The number of states is the size of @dict@.
-- Fails with a message containing the @input@ path
-- if the flattened labelling is empty.
trainSupervised ::
   (PathClass.AbsRel ar) =>
   Map String HMM.State ->
   Path.File ar ->
   [Named.NonEmptySignal] -> LabelChain.T Int String ->
   ME.Exceptional String (HMM.GaussianTrained Double)
trainSupervised dict input sig labels = do
   labelSig <-
      ME.fromMaybe
         (printf "%s: no labels for supervised training" $
          Path.toString input) $
      NonEmpty.fetch $ flattenIntervals dict labels
   return $
      HMM.trainSupervised (Map.size dict) $
      NonEmptyC.zip labelSig (prepare sig)

-- | Train on every element of a non-empty collection, evaluating the
-- individual results in parallel to normal form, merge them with
-- 'HMM.mergeTrained' and finish the training.
trainMany ::
   (Traversable f) =>
   (trainingData -> HMM.GaussianTrained Double) ->
   NonEmpty.T f trainingData -> HMM.Gaussian Double
trainMany train =
   HMM.finishTraining .
   NonEmpty.foldl1 HMM.mergeTrained .
   Par.withStrategy (Par.parTraversable Par.rdeepseq) .
   fmap train


-- | Parameters that control when iterated training stops.
data Convergence =
   Convergence {
      -- cvgMaxIter: upper bound for the number of models emitted after
      --    the initial one by 'takeUntilConvergence'.
      -- cvgSubIter: step width with which the stream of models is sieved.
      cvgMaxIter, cvgSubIter :: Int,
      -- Models count as converged once their deviation from the
      -- predecessor does not exceed this value.
      cvgTolerance :: Double
   }

-- | Command line parser for 'Convergence' with the defaults
-- 100 (@--max-iterations@), 10 (@--sub-iterations@) and 1e-5 (@--tolerance@).
convergenceOptions :: OP.Parser Convergence
convergenceOptions =
   OP.liftA3 Convergence
      (OP.option OP.auto $
         OP.value 100 <>
         OP.long "max-iterations" <>
         OP.metavar "NUMBER" <>
         OP.help "maximal number of iterations for unsupervised training")
      (OP.option OP.auto $
         OP.value 10 <>
         OP.long "sub-iterations" <>
         OP.metavar "NUMBER" <>
         OP.help "number of sub-iterations per iteration")
      (OP.option OP.auto $
         OP.value 1e-5 <>
         OP.long "tolerance" <>
         OP.metavar "PROB" <>
         OP.help "convergence tolerance for unsupervised training")

-- | Thin out a stream of successively trained models and cut it off
-- once it has converged.
--
-- The stream is sieved with step 'cvgSubIter'. The first sieved model is
-- always emitted. Each following model is emitted as long as its
-- 'HMM.deviation' from its predecessor exceeds 'cvgTolerance',
-- but at most 'cvgMaxIter' such models are emitted.
-- An empty input yields an empty output.
--
-- Adjacent pairs are formed over the /whole/ sieved list, so the
-- deviation between the first and the second sieved model is checked
-- and the second model is part of the result. Pairing only the tail
-- would silently drop the second model and allow one sieved step more
-- than 'cvgMaxIter'.
takeUntilConvergence ::
   Convergence -> [HMM.Gaussian Double] -> [HMM.Gaussian Double]
takeUntilConvergence opt =
   untilConverged . ListHT.sieve (cvgSubIter opt)
   where
      untilConverged [] = []
      untilConverged hmms@(hmm:_) =
         hmm :
         (map snd . take (cvgMaxIter opt) . takeWhile fst $
          ListHT.mapAdjacent
             (\hmm0 hmm1 ->
                (HMM.deviation hmm0 hmm1 > cvgTolerance opt, hmm1))
             hmms)