{-# language BangPatterns #-}
{-# language DeriveDataTypeable #-}
{-# language ExplicitNamespaces #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# language MagicHash #-}
{-# language NoImplicitPrelude #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language TypeFamilies #-}
module Freq.Internal
(
FreqTrain(..)
, empty
, singleton
, tabulate
, train
, trainWith
, trainWithMany
, Freq(..)
, measure
, Freaky(prob)
, prettyFreqTrain
) where
import Prelude
()
import Control.Applicative (Applicative(pure))
import Control.DeepSeq (NFData)
import Control.Monad (Monad((>>=)), (>>), forM_)
import Control.Monad.ST (ST,runST)
import Data.Binary (Binary(..))
import Data.Bool (otherwise)
import Data.ByteString.Internal (ByteString(..), w2c)
import Data.Data (Data)
import Data.Eq (Eq((==)))
import Data.Foldable (Foldable(foldMap, sum))
import Data.Function ((.), ($))
import Data.Functor (fmap)
import Data.List ((++))
import Data.Map.Strict.Internal (Map)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe)
import Data.Monoid (Monoid(mempty, mappend))
import Data.Ord (Ord(min, (<)))
import Data.Primitive.ByteArray (ByteArray,foldrByteArray)
import Data.Semigroup (Semigroup((<>)))
import Data.Set (Set)
import Data.String (String)
import Data.Word (Word8)
import GHC.Base (Double, Int(I#), build)
import GHC.Err (undefined)
import GHC.IO (FilePath, IO)
import GHC.Num ((+), (*), (-))
import GHC.Read (Read)
import GHC.Real ((/), mod)
import GHC.Show (Show(show))
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Unsafe as BU
import qualified Data.Map.Strict as DMS
import qualified Data.Primitive.ByteArray as PM
import qualified Data.Primitive.Types as PM
import qualified Data.Set as S
import qualified GHC.OldList as L
import qualified Numeric as Numeric
import qualified Prelude as P
-- | Models that can score the transition from one byte to the next.
class Freaky a where
  -- | @prob model x y@ is the score the model gives to byte @y@ appearing
  -- directly after byte @x@.
  prob
    :: a      -- ^ the model
    -> Word8  -- ^ first byte of the pair
    -> Word8  -- ^ byte that follows it
    -> Double
-- | Average of 'prob' over every pair of adjacent bytes in the input.
--
-- Inputs shorter than two bytes contain no pair and score @0@.
measure :: Freaky a => a -> BC.ByteString -> Double
measure f !b
  | len < 2 = 0
  | otherwise = accumulate 0 0 / P.fromIntegral pairCount
  where
    len :: Int
    len = BC.length b

    -- A string of n bytes has n - 1 adjacent pairs.
    pairCount :: Int
    pairCount = len - 1

    accumulate :: Int -> Double -> Double
    accumulate !ix !total
      | ix == pairCount = total
      | otherwise = accumulate (ix + 1) (prob f here next + total)
      where
        -- Unchecked indexing is in bounds: ix < pairCount = len - 1.
        here = BU.unsafeIndex b ix
        next = BU.unsafeIndex b (ix + 1)
{-# INLINE measure #-}
-- | Trainable bigram table: the outer key is the first byte of a pair, the
-- inner key is the byte that follows it, and the value is the weight
-- accumulated for that pair.
newtype FreqTrain = FreqTrain
  { _getFreqTrain :: Map Word8 (Map Word8 Double)
  }
  deriving (Data, Eq, NFData, Ord, Read, Show)
-- | Scores a pair as its weight divided by the total weight of its row.
-- A pair that is missing from the table scores @0@.
instance Freaky FreqTrain where
  prob (FreqTrain table) w1 w2 = fromMaybe 0 $ do
    row <- DMS.lookup w1 table
    weight <- DMS.lookup w2 row
    pure (ratio weight row)
  {-# INLINE prob #-}
-- | Combining two tables adds the weights of matching byte pairs.
instance Semigroup FreqTrain where
  FreqTrain lhs <> FreqTrain rhs = FreqTrain (union lhs rhs)
  {-# INLINE (<>) #-}
-- | The identity is the 'empty' table; 'mappend' is the same operation as
-- the 'Semigroup' instance's '<>'.
instance Monoid FreqTrain where
  mempty = empty
  {-# INLINE mempty #-}
  mappend = (<>)
  {-# INLINE mappend #-}
-- | A table containing no byte pairs.
empty :: FreqTrain
empty = FreqTrain mempty
{-# INLINE empty #-}
-- | A table holding exactly one byte pair with the given weight.
singleton
  :: Word8   -- ^ first byte of the pair
  -> Word8   -- ^ byte that follows it
  -> Double  -- ^ weight stored for the pair
  -> FreqTrain
singleton k ka w = FreqTrain (DMS.singleton k row)
  where
    row = DMS.singleton ka w
{-# INLINE singleton #-}
-- | Convert a trained table into the array-backed 'Freq' form.
tabulate :: FreqTrain -> Freq
tabulate trained = tabulateInternal trained
{-# INLINE tabulate #-}
-- | Build a table by counting every adjacent byte pair of the input.
train :: BC.ByteString -> FreqTrain
train = tally
{-# INLINE train #-}
-- | Read the file at the given path and build a table from its contents.
trainWith :: FilePath -> IO FreqTrain
trainWith path = fmap tally (BC.readFile path)
{-# INLINE trainWith #-}
-- | Train on every file in the collection and merge the resulting tables
-- with the 'Monoid' instance of 'FreqTrain'.
trainWithMany :: Foldable t => t FilePath -> IO FreqTrain
trainWithMany = foldMap trainWith
{-# INLINE trainWithMany #-}
-- | Print a table to stdout.
--
-- Each first byte is printed on its own line, followed by one line per
-- successor byte showing that byte and its weight rounded to an integer.
-- A newline byte (10) is rendered as the two characters @\\n@ in both
-- positions, so it cannot break the one-entry-per-line layout.
prettyFreqTrain :: FreqTrain -> IO ()
prettyFreqTrain (FreqTrain m) = DMS.foldMapWithKey printRow m
  where
    -- Render a byte as text, escaping newline.
    showByte :: Word8 -> String
    showByte c
      | c == 10 = "\\n"
      | otherwise = [w2c c]

    -- Print the row header, then every successor entry of that row.
    printRow :: Word8 -> Map Word8 Double -> IO ()
    printRow c1 row =
      P.putStrLn (showByte c1) >> DMS.foldMapWithKey printCell row

    printCell :: Word8 -> Double -> IO ()
    printCell c2 prb =
      P.putStrLn (" " ++ showByte c2 ++ " " ++ P.show (P.round prb :: Int))
-- | Array-backed frequency table, as produced by 'tabulate'.
data Freq = Freq
  { _Dim :: !Int
    -- ^ side length of the square table
  , _2d :: !ByteArray
    -- ^ the square table of 'Double's, stored row by row
  , _Flat :: !ByteArray
    -- ^ per-byte lookup giving each byte's row/column index in the square
  }
  deriving Eq
-- | View the contents of a 'ByteArray' as a list of primitive elements.
-- Written with 'build' so the result can take part in list fusion.
toList :: PM.Prim a => ByteArray -> [a]
toList arr = build (\cons nil -> foldrByteArray cons nil arr)
{-# INLINE toList #-}
-- | Read a 'ByteArray' as a list of 'Double's.
toDoubles :: ByteArray -> [Double]
toDoubles arr = toList arr
-- | Read a 'ByteArray' as a list of 'Word8's.
toWord8s :: ByteArray -> [Word8]
toWord8s arr = toList arr
-- | Serialised as a triple: the dimension, the square table as a list of
-- 'Double's, and the index table as a list of 'Word8's.
instance Binary Freq where
  put (Freq dim square ixs) = put (dim, toDoubles square, toWord8s ixs)
  get = fmap rebuild get
    where
      rebuild :: (Int, [Double], [Word8]) -> Freq
      rebuild (dim, asDoubles, asWord8s) =
        Freq dim (PM.byteArrayFromList asDoubles) (PM.byteArrayFromList asWord8s)
-- | Looks up each byte's position in the index table, then reads the
-- corresponding cell of the square table.
instance Freaky Freq where
  prob (Freq sz square ixs) chrFst chrSnd =
    PM.indexByteArray square (sz * row + col)
    where
      -- Position of a byte within the square, taken from the index table.
      slot :: Word8 -> Int
      slot c = word8ToInt (PM.indexByteArray ixs (word8ToInt c))

      !row = slot chrFst
      !col = slot chrSnd
  {-# INLINE prob #-}
-- | Human-readable dump: the dimensions, then the square table with one row
-- per line (two decimals per cell), then the index table, 16 entries per
-- line.
instance P.Show Freq where
  show (Freq i arr ixs) =
    P.show i ++ "x" ++ show i
      ++ "\n"
      ++ "\n2D Array: \n"
      ++ squareCells 0
      ++ "\n256 Array: \n"
      ++ indexCells 0
    where
      -- Number of Doubles stored in the square table.
      cellCount :: Int
      cellCount = P.div (PM.sizeofByteArray arr) (sizeOf (undefined :: Double))

      -- A newline after the last entry of a line, nothing otherwise.
      lineBreak :: P.Bool -> String
      lineBreak atLineEnd = if atLineEnd then "\n" else ""

      indexCells :: Int -> String
      indexCells !ix
        | ix < PM.sizeofByteArray ixs =
            show (PM.indexByteArray ixs ix :: Word8)
              ++ " "
              ++ lineBreak (ix `mod` 16 == 15)
              ++ indexCells (ix + 1)
        | otherwise = ""

      squareCells :: Int -> String
      squareCells !ix
        | ix < cellCount =
            showFloat (PM.indexByteArray arr ix :: Double)
              ++ " "
              ++ lineBreak (ix `mod` i == i - 1)
              ++ squareCells (ix + 1)
        | otherwise = ""
-- | Render a floating-point number in fixed notation with two decimals.
showFloat :: P.RealFloat a => a -> String
showFloat x = Numeric.showFFloat decimals x ""
  where
    decimals = Just 2
-- | Size in bytes of a primitive value's type; the argument itself is only
-- used to select the type.
sizeOf :: PM.Prim a => a -> Int
sizeOf proxy = PM.sizeOf proxy
-- | Widen a byte to an 'Int'.
word8ToInt :: Word8 -> Int
word8ToInt = P.fromIntegral
{-# INLINE word8ToInt #-}
-- | Narrow an 'Int' to a byte via 'P.fromIntegral'; values outside 0..255
-- wrap around modulo 256.
intToWord8 :: Int -> Word8
intToWord8 = P.fromIntegral
{-# INLINE intToWord8 #-}
-- | Flatten a 'FreqTrain' into the array-backed 'Freq' representation.
--
-- Two arrays are built:
--
-- * a 256-entry index table mapping every byte to its row/column position
--   in the square table; bytes that never occur in the training data map
--   to position @sz - 1@;
--
-- * an @sz * sz@ table of 'Double's, stored row by row, where the cell for
--   a pair holds the pair's weight divided by the total weight of its row.
--   Cells for pairs absent from the training data stay @0@.
--
-- @sz@ is the number of distinct bytes seen plus one, capped at 256. While
-- fewer than 256 distinct bytes were seen, position @sz - 1@ is never
-- assigned to a seen byte, so unseen bytes land on an all-zero row/column.
tabulateInternal :: FreqTrain -> Freq
tabulateInternal (FreqTrain m) = runST comp
  where
    comp :: forall s. ST s Freq
    comp = do
      let allChars :: Set Word8
          !allChars = S.union (DMS.keysSet m) (foldMap DMS.keysSet m)
          -- Pair each row with its total so the divisor is computed once
          -- per row rather than once per cell.
          m' :: Map Word8 (Double, Map Word8 Double)
          !m' = fmap (\x -> (sum x, x)) m
          !sz = min (S.size allChars + 1) 256
          !szSq = sz * sz
          -- (position in the square, byte), in ascending byte order.
          ixedChars :: [(Word8, Word8)]
          !ixedChars = L.zip (P.enumFrom (0 :: Word8)) (S.toList allChars)
      ixs <- PM.newByteArray 256
      square <- PM.newByteArray (szSq * sizeOf (undefined :: Double))
      -- Freshly allocated arrays are uninitialised: zero the whole square
      -- and point every byte at the default position before filling in.
      PM.setByteArray square 0 szSq (0 :: Double)
      PM.fillByteArray ixs 0 256 (intToWord8 (sz - 1))
      forM_ ixedChars $ \(ixFst, w8Fst) -> do
        PM.writeByteArray ixs (word8ToInt w8Fst) ixFst
        -- The row lookup does not depend on the second byte, so do it once
        -- per row. Bytes that only occur as successors have no row and
        -- keep their zeros.
        case DMS.lookup w8Fst m' of
          Nothing -> pure ()
          Just (total, row) -> do
            let !rowBase = sz * word8ToInt ixFst
            forM_ ixedChars $ \(ixSnd, w8Snd) ->
              case DMS.lookup w8Snd row of
                Nothing -> pure ()
                Just v ->
                  PM.writeByteArray square (rowBase + word8ToInt ixSnd) (v / total)
      frozenIxs <- PM.unsafeFreezeByteArray ixs
      frozenSquare <- PM.unsafeFreezeByteArray square
      pure (Freq sz frozenSquare frozenIxs)
-- | Count every adjacent byte pair of the input, adding a weight of 1 per
-- occurrence. An empty input yields the 'empty' table.
tally :: BC.ByteString -> FreqTrain
tally !b
  | BC.null b = empty
  | otherwise = walk 0 mempty
  where
    -- Index of the final byte; the walk stops there because that byte has
    -- no successor.
    lastIx :: Int
    lastIx = BC.length b - 1

    walk :: Int -> FreqTrain -> FreqTrain
    walk !ix !acc
      | ix == lastIx = acc
      | otherwise = walk (ix + 1) (mappend (singleton here next 1) acc)
      where
        -- Unchecked indexing is in bounds: ix < lastIx in this branch.
        here = BU.unsafeIndex b ix
        next = BU.unsafeIndex b (ix + 1)
-- | A weight expressed as a fraction of the total weight of its row.
ratio :: Double -> Map Word8 Double -> Double
ratio !weight row = weight / rowTotal
  where
    rowTotal = sum row
{-# INLINE ratio #-}
-- | Raw nested-map shape used by 'union': first byte, then following byte,
-- then a 'Double' value.
type Tal = Map Word8 (Map Word8 Double)
-- | Merge two nested maps; where both contain the same byte pair the two
-- values are added.
union :: Tal -> Tal -> Tal
union = DMS.unionWith mergeRows
  where
    mergeRows = DMS.unionWith (+)
{-# INLINE union #-}