{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ExplicitNamespaces #-}
module Text.HaskSeg.Counts (cleanCounts, initializeCounts, updateCounts, addCounts, subtractCounts) where
import Control.Monad.Random
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import Text.Printf (printf, PrintfArg(..), fmtPrecision, fmtChar, errorBadFormat, formatString, vFmt, IsChar)
import Control.Monad.Log
import Control.Monad.State.Class (MonadState(get, put))
import Control.Monad.Reader.Class
import Control.Monad.Reader (ReaderT)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.State.Strict
import Data.Tuple (swap)
import Data.List (unfoldr, nub, mapAccumL, intercalate, sort, foldl1')
import Text.HaskSeg.Probability (Prob, LogProb, Probability(..), showDist, sampleCategorical)
import Text.HaskSeg.Types (Locations, Morph, Counts, Site, Location(..), Lookup, showLookup, showCounts, SamplingState(..), Params(..))
-- | Drop every entry whose count is exactly zero, keeping all other entries
-- (including any with a negative count) untouched.
cleanCounts :: Counts elem -> Counts elem
cleanCounts = Map.filter (/= 0)
-- | Build the initial morph counts from a sequence of locations.
--
-- The locations are cut into morphs after every location whose '_morphFinal'
-- flag is set.  Each morph is the vector of the '_value's of its locations,
-- and the result maps every distinct morph to its number of occurrences.
--
-- A trailing run of locations that is not closed by a '_morphFinal' location
-- is counted as one last morph, so the function is total instead of failing
-- with a non-exhaustive pattern match on such input.
initializeCounts :: (Ord elem, Show elem) => Locations elem -> Counts elem
initializeCounts ls = Map.fromListWith (+) [(morph, 1) | morph <- morphs]
  where
    morphs = map (Vector.fromList . map _value) (unfoldr nextMorph (Vector.toList ls))

    -- Peel one morph (up to and including its final location) off the front
    -- of the remaining locations; stop once nothing is left.
    nextMorph [] = Nothing
    nextMorph locs = case break _morphFinal locs of
      (body, boundary : rest) -> Just (body ++ [boundary], rest)
      -- No boundary left: the unterminated tail is itself the last morph.
      (body, [])              -> Just (body, [])
-- | Combine a new amount into the count stored for a morph.
--
-- Arguments are the combining function, the morph, the new amount and the
-- counts to update.  If the morph is absent, the new amount is inserted
-- unchanged; if it is present, the stored count becomes
-- @combine newAmount oldCount@ (note the argument order, which matters for
-- non-commutative functions).
updateCounts :: (Ord elem) => (Int -> Int -> Int) -> Morph elem -> Int -> Counts elem -> Counts elem
updateCounts = Map.insertWith
-- | Increase the count of a morph by the given amount.  A morph that is not
-- yet present is inserted with that amount as its count.
addCounts :: (Ord elem) => Morph elem -> Int -> Counts elem -> Counts elem
addCounts morph amount = updateCounts (+) morph amount
-- | Decrease the count of a morph by the given amount.
--
-- This is implemented as adding the negated amount rather than inserting with
-- a subtracting function: the combining function is only applied when the
-- morph already exists, so the latter would store @+amount@ for an absent
-- morph.  Here a present morph ends up with @oldCount - amount@ and an absent
-- morph with @-amount@.
subtractCounts :: (Ord elem) => Morph elem -> Int -> Counts elem -> Counts elem
subtractCounts morph amount = updateCounts (+) morph (negate amount)