--------------------------------------------------------------------------------

{-# language BangPatterns               #-}
{-# language DeriveDataTypeable         #-}
{-# language ExplicitNamespaces         #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# language MagicHash                  #-}
{-# language NoImplicitPrelude          #-}
{-# language ScopedTypeVariables        #-}
{-# language UnboxedTuples              #-}
{-# language TypeFamilies               #-}

--------------------------------------------------------------------------------

{-| This is the internal module to 'Freq'.
    The primary differences are that this
    module exports the typeclass 'Freaky',
    as well as the data constructors of
    'FreqTrain' and 'Freq'.
-}

module Freq.Internal
  ( -- * Frequency table type
    FreqTrain(..)

    -- * Construction
  , empty
  , singleton

  , tabulate

    -- * Training
  , train
  , trainWith
  , trainWithMany

    -- * Using a trained model
  , Freq(..)
  , measure
  , Freaky(prob)

    -- * Pretty Printing
  , prettyFreqTrain
  ) where

--------------------------------------------------------------------------------

import           Prelude
  ()

import           Control.Applicative (Applicative(pure))
import           Control.DeepSeq (NFData)
import           Control.Monad (Monad((>>=)), (>>), forM_)
import           Control.Monad.ST (ST,runST)
import           Data.Binary (Binary(..))
import           Data.Bool (otherwise)
import           Data.ByteString.Internal (ByteString(..), w2c)
import           Data.Data (Data)
import           Data.Eq (Eq((==)))
import           Data.Foldable (Foldable(foldMap, sum))
import           Data.Function ((.), ($))
import           Data.Functor (fmap)
import           Data.List ((++))
import           Data.Map.Strict.Internal (Map)
import           Data.Maybe (Maybe(Just, Nothing), fromMaybe)
import           Data.Monoid (Monoid(mempty, mappend))
import           Data.Ord (Ord(min, (<)))
import           Data.Primitive.ByteArray (ByteArray,foldrByteArray)
import           Data.Semigroup (Semigroup((<>)))
import           Data.Set (Set)
import           Data.String (String)
import           Data.Word (Word8)

import           GHC.Base (Double, Int(I#), build)
import           GHC.Err (undefined)
import           GHC.IO (FilePath, IO)
import           GHC.Num ((+), (*), (-))
import           GHC.Read (Read)
import           GHC.Real ((/), mod)
import           GHC.Show (Show(show))

import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Unsafe as BU
import qualified Data.Map.Strict as DMS
import qualified Data.Primitive.ByteArray as PM
import qualified Data.Primitive.Types as PM
import qualified Data.Set as S
import qualified GHC.OldList as L
import qualified Numeric as Numeric
import qualified Prelude as P

--------------------------------------------------------------------------------

-- | @'Freaky'@ is a typeclass that wraps the @'prob'@ function,
--   which allows for an extensible definition of @'measure'@.
--
--   It is used internally.
class Freaky a where
  -- | Given a Frequency table and characters 'c1' and 'c2',
  --   what is the probability that 'c1' follows 'c2'?
  prob :: a -> Word8 -> Word8 -> Double

-- | Given a Frequency table and a @'ByteString'@, @'measure'@
--   returns the probability that the @'ByteString'@ is not
--   randomised. The accuracy of @'measure'@ is is heavily affected
--   by your training data.
measure :: Freaky a => a -> BC.ByteString -> Double
measure _ (PS _ _ 0) = 0
measure _ (PS _ _ 1) = 0
measure f !b         = (go 0 0) / (P.fromIntegral l)
  where
    l :: Int
    l = BC.length b - 1

    go :: Int -> Double -> Double
    go !p !acc
      | p == l = acc
      | otherwise =
          let k = BU.unsafeIndex b p
              r = BU.unsafeIndex b (p + 1)
          in go (p + 1) (prob f k r + acc)
{-# INLINE measure #-}

--------------------------------------------------------------------------------

-- | A @'FreqTrain'@ is a digram-based frequency table.
--   
--   One can construct a @'FreqTrain'@ with @'train'@,
--   @'trainWith'@, or @'trainWithMany'@.
--
--   One can use a trained @'FreqTrain'@ with @'prob'@
--   and @'measure'@.
--   
--   @'mappend' == '<>'@ will add the values of each
--   of the matching keys.
--
--   It is highly recommended to convert a @'FreqTrain'@
--   to a @'Freq'@ with @'tabulate'@ before using the trained model,
--   because @'Freq'@s have /O(1)/ reads as well as significantly
--   faster constant-time operations, however keep in mind
--   that @'Freq'@s cannot be neither modified nor converted
--   back to a @'FreqTrain'@.
--
newtype FreqTrain = FreqTrain { _getFreqTrain :: Map Word8 (Map Word8 Double) }
  deriving
    ( Data
    , Eq
    , NFData
    , Ord
    , Read
    , Show
    )

instance Freaky FreqTrain where
  prob (FreqTrain f) w1 w2 =
    case DMS.lookup w1 f of
      Nothing -> 0
      Just g -> case DMS.lookup w2 g of
        Nothing -> 0
        Just weight -> ratio weight g
  {-# INLINE prob #-}

instance Semigroup FreqTrain where
  {-# INLINE (<>) #-}
  (FreqTrain a) <> (FreqTrain b) = FreqTrain (union a b)

instance Monoid FreqTrain where
  {-# INLINE mempty #-}
  mempty  = empty
  {-# INLINE mappend #-}
  (FreqTrain a) `mappend` (FreqTrain b) = FreqTrain (union a b)

--------------------------------------------------------------------------------

-- | /O(1)/. The empty frequency table.
empty :: FreqTrain
empty = FreqTrain DMS.empty
{-# INLINE empty #-}

-- | /O(1)/. A Frequency table with a single entry.
singleton :: Word8  -- ^ Outer key
          -> Word8  -- ^ Inner key
          -> Double -- ^ Weight
          -> FreqTrain   -- ^ The singleton frequency table
singleton k ka w = FreqTrain $ DMS.singleton k (DMS.singleton ka w)
{-# INLINE singleton #-}

-- | Optimise a 'FreqTrain' for /O(1)/ read access.
tabulate :: FreqTrain -> Freq
tabulate = tabulateInternal
{-# INLINE tabulate #-}

--------------------------------------------------------------------------------

-- | Given a @'BC.ByteString'@ consisting of training data,
--   build a Frequency table.
train :: BC.ByteString
      -> FreqTrain
train !b = tally b
{-# INLINE train #-}

-- | Given a @'FilePath'@ containing training data, build a
--   Frequency table inside of the @'IO'@ monad.
trainWith :: FilePath -- ^ @'FilePath'@ containing training data
          -> IO FreqTrain  -- ^ Frequency table generated as a result of training, inside of @'IO'@.
trainWith path = BC.readFile path >>= (pure . tally)
{-# INLINE trainWith #-}

-- | Given a list of @'FilePath'@ containing training data,
--   build a Frequency table inside of the @'IO'@ monad.
trainWithMany :: Foldable t
              => t FilePath -- ^ @'FilePath'@s containing training data
              -> IO FreqTrain    -- ^ Frequency table generated as a result of training, inside of @'IO'@.
trainWithMany paths = foldMap trainWith paths
{-# INLINE trainWithMany #-}

--------------------------------------------------------------------------------

-- | Pretty-print a @'FreqTrain'@.
--
prettyFreqTrain :: FreqTrain -> IO ()
prettyFreqTrain (FreqTrain m)
  = DMS.foldMapWithKey
      (\c1 m' ->
         P.putStrLn (if c1 == 10 then "\\n" else [w2c c1])
           >> DMS.foldMapWithKey
               (\c2 prb -> P.putStrLn ("  " ++ [w2c c2] ++ " " ++ P.show (P.round prb :: Int))) m') m

--------------------------------------------------------------------------------

-- | A variant of @'FreqTrain'@ that holds identical information but
--   is optimised for reads. There are no operations that imbue
--   a @'Freq'@ with additional information.
--
--   Reading from a @'Freq'@ is orders of magnitude faster
--   than reading from a @'FreqTrain'@. It is /highly/
--   recommended that you use your trained model by first
--   converting a @'FreqTrain'@ to a @'Freq'@ with @'tabulate'@.
data Freq = Freq
  { _Dim :: !Int
    -- ^ Width and height of square 2d array
  , _2d  :: !ByteArray
    -- ^ Square two-dimensional array of Double, maps first char and second char to probability
  , _Flat :: !ByteArray
    -- ^ Array of Word8, length 256, acts as map from Word8 to table row/column index
  }
  deriving (Eq)

toList :: PM.Prim a => ByteArray -> [a]
toList xs = build (\c n -> foldrByteArray c n xs)
{-# INLINE toList #-}

toDoubles :: ByteArray -> [Double]
toDoubles = toList

toWord8s :: ByteArray -> [Word8]
toWord8s = toList

instance Binary Freq where
  put (Freq dim ds ws) = put (dim,toDoubles ds,toWord8s ws)
  get = do
    (dim :: Int,asDoubles :: [Double],asWord8s :: [Word8]) <- get
    pure (Freq dim (PM.byteArrayFromList asDoubles) (PM.byteArrayFromList asWord8s))

instance Freaky Freq where
  {-# INLINE prob #-}
  prob (Freq sz square ixs) chrFst chrSnd =
     let !ixFst = word8ToInt (PM.indexByteArray ixs (word8ToInt chrFst))
         !ixSnd = word8ToInt (PM.indexByteArray ixs (word8ToInt chrSnd))
     in PM.indexByteArray square (sz * ixFst + ixSnd)

-- | This exists for debugging purposes
instance P.Show Freq where
  show (Freq i arr ixs) =
      P.show i ++ "x" ++ show i
   ++ "\n"
   ++ "\n2D Array: \n"
   ++ go 0
   ++ "\n256 Array: \n"
   ++ ho 0
    where
      ho :: Int -> String
      ho !ix = if ix < PM.sizeofByteArray ixs
        then
          let col = ix `mod` 16
              extra = if col == 15 then "\n" else ""
          in show (PM.indexByteArray ixs ix :: Word8) ++ " " ++ extra ++ ho (ix + 1)
        else ""

      go :: Int -> String
      go !ix = if ix < elemSz
        then
          let col = ix `mod` i
              extra = if col == (i - 1) then "\n" else ""
          in showFloat (PM.indexByteArray arr ix :: Double) ++ " " ++ extra ++ go (ix + 1)
        else ""
        where
          !elemSz = P.div (PM.sizeofByteArray arr) (sizeOf (undefined :: Double))
          showFloat :: P.RealFloat a => a -> String
          showFloat !x = Numeric.showFFloat (Just 2) x ""

--------------------------------------------------------------------
--  Internal Section                                              --
--------------------------------------------------------------------

sizeOf :: PM.Prim a => a -> Int
sizeOf x = I# (PM.sizeOf# x)

word8ToInt :: Word8 -> Int
word8ToInt !w = P.fromIntegral w
{-# INLINE word8ToInt #-}

intToWord8 :: Int -> Word8
intToWord8 !i = P.fromIntegral i
{-# INLINE intToWord8 #-}

-- | Optimise a 'FreqTrain' for /O(1)/ read access.
--
tabulateInternal :: FreqTrain -> Freq
tabulateInternal (FreqTrain m) = runST comp where
  comp :: forall s. ST s Freq
  comp = do
    let allChars :: Set Word8
        !allChars = S.union (DMS.keysSet m) (foldMap DMS.keysSet m)
        m' :: Map Word8 (Double, Map Word8 Double)
        !m' = fmap (\x -> (sum x, x)) m
        !sz = min (S.size allChars + 1) 256
        !szSq = sz * sz
        ixedChars :: [(Word8,Word8)]
        !ixedChars = L.zip (P.enumFrom (0 :: Word8)) (S.toList allChars)
    ixs <- PM.newByteArray 256
    square <- PM.newByteArray (szSq * sizeOf (undefined :: Double))
    let fillSquare :: Int -> ST s ()
        fillSquare !i = if i < szSq
          then do
            PM.writeByteArray square i (0 :: Double)
            fillSquare (i + 1)
          else pure ()
    fillSquare 0
    PM.fillByteArray ixs 0 256 (intToWord8 (sz - 1))
    forM_ ixedChars $ \(ixFst,w8Fst) -> do
      PM.writeByteArray ixs (word8ToInt w8Fst) ixFst --w8Fst
      forM_ ixedChars $ \(ixSnd,w8Snd) -> do
        let r = fromMaybe 0 $ do
              (total, m'') <- DMS.lookup w8Fst m'
              v <- DMS.lookup w8Snd m''
              pure (v / total)
        PM.writeByteArray square (sz * (word8ToInt ixFst) + (word8ToInt ixSnd)) r
    frozenIxs <- PM.unsafeFreezeByteArray ixs
    frozenSquare <- PM.unsafeFreezeByteArray square
    pure (Freq sz frozenSquare frozenIxs)

-- | Build a frequency table from a ByteString.
tally :: BC.ByteString -- ^ ByteString with which the FreqTrain will be built
      -> FreqTrain     -- ^ Resulting FreqTrain
tally (PS _ _ 0) = empty
tally !b = go 0 mempty
  where
    l :: Int
    l = BC.length b - 1

    go :: Int -> FreqTrain -> FreqTrain
    go !p !fr
      | p == l = fr
      | otherwise =
          let k = BU.unsafeIndex b p
              r = BU.unsafeIndex b (p + 1)
          in go (p + 1) (mappend (singleton k r 1) fr)

ratio :: Double -> Map Word8 Double -> Double
ratio !weight g = weight / (sum g)
{-# INLINE ratio #-}

-- A convenience type synonym for internal use.
-- Please /do not/ ask me to rename this to "Tally".
-- https://github.com/andrewthad did this.
-- He is no longer permitted to leave the cave.
--
-- Tal is named after Mikhail Tal.
type Tal = Map Word8 (Map Word8 Double)

-- union two 'Tal', summing the weights belonging to the same keys.
union :: Tal -> Tal -> Tal
union a b = DMS.unionWith (DMS.unionWith (+)) a b
{-# INLINE union #-}