{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
--------------------------------------------------------------------
-- |
-- Copyright :  (c) Edward Kmett 2013
-- License   :  BSD3
-- Maintainer:  Edward Kmett <ekmett@gmail.com>
-- Stability :  experimental
-- Portability: non-portable
--
--------------------------------------------------------------------
module Data.Approximate.Mass
  ( Mass(..)
  , (|?), (&?), (^?)
  ) where

import Control.Applicative
import Control.Comonad
import Control.DeepSeq
import Control.Monad
import Data.Binary as Binary
import Data.Bytes.Serial as Bytes
import Data.Copointed
import Data.Data
import Data.Foldable
import Data.Functor.Bind
import Data.Functor.Extend
import Data.Hashable
import Data.Hashable.Extras
import Data.Pointed
import Data.SafeCopy
import Data.Semigroup
import Data.Serialize as Serialize
import Data.Traversable
import Data.Vector.Generic as G
import Data.Vector.Generic.Mutable as M
import Data.Vector.Unboxed as U
import Generics.Deriving
import Numeric.Log

-- | A value of type @a@ paired with a lower bound on its probability mass.
--
-- 'Mass' is a 'Monad': every bind multiplies the bounds together, so each
-- further step of a computation describes a consequence that is no more
-- likely than the one it was derived from.
--
-- /NB:/ The bound is stored as a @'Log' 'Double'@, i.e. in the log domain.
-- This retains accuracy across very long chains of multiplications. This
-- module only multiplies bounds or takes their 'max'; it never adds them, so
-- the extra cost of addition in the log domain is only paid when moving
-- values in and out of that representation.
--
-- This is most useful for discrete types, such as small 'Integral'
-- instances or a 'Bounded' 'Enum' like 'Bool'.
--
-- Also note that @('&?')@ and @('|?')@ exploit knowledge of the underlying
-- boolean function to give tighter bounds than naively using
-- @'liftA2' ('&&')@ or @'liftA2' ('||')@.
data Mass a
  = Mass {-# UNPACK #-} !(Log Double) a
  deriving (Eq, Ord, Show, Read, Typeable, Data, Generic)

-- | Writes the probability bound followed by the payload; 'get' reads the
-- two components back in that same order.
instance Binary a => Binary (Mass a) where
  put (Mass prob val) = do
    Binary.put prob
    Binary.put val
  get = liftA2 Mass Binary.get Binary.get

-- | Writes the probability bound followed by the payload; 'get' reads the
-- two components back in that same order.
instance Serialize a => Serialize (Mass a) where
  put (Mass prob val) = do
    Serialize.put prob
    Serialize.put val
  get = liftA2 Mass Serialize.get Serialize.get

-- | Defined without a @where@ body, so every method comes from the class
-- defaults.
--
-- NOTE(review): presumably those defaults route through the 'Serialize'
-- instance for @'Mass' a@ and use the initial version number -- confirm
-- against the @safecopy@ release in use.
instance Serialize a => SafeCopy (Mass a)

-- | No methods are overridden; hashing relies entirely on the class
-- defaults.
--
-- NOTE(review): presumably these are the 'Generic'-based defaults, which
-- would be why 'Mass' derives 'Generic' -- confirm.
instance Hashable a => Hashable (Mass a)
-- | No methods are overridden; the 'Hashable1' class defaults are used
-- as-is.
instance Hashable1 Mass

-- | Serializes the probability bound itself and delegates the payload to the
-- supplied element writer; deserialization mirrors that order using the
-- supplied element reader.
instance Serial1 Mass where
  serializeWith putElem (Mass prob val) = do
    serialize prob
    putElem val
  deserializeWith getElem = liftA2 Mass deserialize getElem

-- | Serializes the probability bound followed by the payload, and
-- deserializes them in the same order.
instance Serial a => Serial (Mass a) where
  serialize (Mass prob val) = do
    serialize prob
    serialize val
  deserialize = liftA2 Mass deserialize deserialize

-- | Maps over the payload; the probability bound is carried over unchanged.
instance Functor Mass where
  fmap g m = case m of
    Mass prob x -> Mass prob (g x)
  {-# INLINE fmap #-}

-- | Folds over the single payload value, ignoring the probability bound.
instance Foldable Mass where
  foldMap g m = case m of
    Mass _ x -> g x
  {-# INLINE foldMap #-}

-- | Mutable unboxed vector of 'Mass' values, represented as a mutable
-- unboxed vector of @(probability bound, payload)@ pairs.
newtype instance U.MVector s (Mass a) = MV_Mass (U.MVector s (Log Double,a))
-- | Immutable unboxed vector of 'Mass' values, represented as an unboxed
-- vector of @(probability bound, payload)@ pairs.
newtype instance U.Vector (Mass a) = V_Mass (U.Vector (Log Double,a))

-- | Mutable unboxed vectors of 'Mass' values. Every operation unwraps the
-- 'MV_Mass' newtype and defers to the underlying mutable vector of
-- @('Log' 'Double', a)@ pairs, converting elements between 'Mass' and the
-- pair form at the boundary.
--
-- NOTE(review): newer releases of @vector@ add a @basicInitialize@ method to
-- this class, which is not defined here -- confirm against the supported
-- @vector@ version range.
instance Unbox a => M.MVector U.MVector (Mass a) where
  basicLength (MV_Mass mv) = M.basicLength mv
  {-# INLINE basicLength #-}
  basicUnsafeSlice off len (MV_Mass mv) = MV_Mass (M.basicUnsafeSlice off len mv)
  {-# INLINE basicUnsafeSlice #-}
  basicOverlaps (MV_Mass va) (MV_Mass vb) = M.basicOverlaps va vb
  {-# INLINE basicOverlaps #-}
  basicUnsafeNew len = liftM MV_Mass (M.basicUnsafeNew len)
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeReplicate len (Mass prob val) = liftM MV_Mass (M.basicUnsafeReplicate len (prob, val))
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeRead (MV_Mass mv) ix = liftM (uncurry Mass) (M.basicUnsafeRead mv ix)
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeWrite (MV_Mass mv) ix (Mass prob val) = M.basicUnsafeWrite mv ix (prob, val)
  {-# INLINE basicUnsafeWrite #-}
  basicClear (MV_Mass mv) = M.basicClear mv
  {-# INLINE basicClear #-}
  basicSet (MV_Mass mv) (Mass prob val) = M.basicSet mv (prob, val)
  {-# INLINE basicSet #-}
  basicUnsafeCopy (MV_Mass va) (MV_Mass vb) = M.basicUnsafeCopy va vb
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeMove (MV_Mass va) (MV_Mass vb) = M.basicUnsafeMove va vb
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeGrow (MV_Mass mv) extra = liftM MV_Mass (M.basicUnsafeGrow mv extra)
  {-# INLINE basicUnsafeGrow #-}

-- | Immutable unboxed vectors of 'Mass' values. Every operation unwraps the
-- 'V_Mass' / 'MV_Mass' newtypes and defers to the underlying vector of
-- @('Log' 'Double', a)@ pairs, rebuilding a 'Mass' from the pair when an
-- element is read.
instance Unbox a => G.Vector U.Vector (Mass a) where
  basicUnsafeFreeze (MV_Mass mv) = liftM V_Mass (G.basicUnsafeFreeze mv)
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeThaw (V_Mass vec) = liftM MV_Mass (G.basicUnsafeThaw vec)
  {-# INLINE basicUnsafeThaw #-}
  basicLength (V_Mass vec) = G.basicLength vec
  {-# INLINE basicLength #-}
  basicUnsafeSlice off len (V_Mass vec) = V_Mass (G.basicUnsafeSlice off len vec)
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeIndexM (V_Mass vec) ix =
    liftM (uncurry Mass) (G.basicUnsafeIndexM vec ix)
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeCopy (MV_Mass mv) (V_Mass vec) = G.basicUnsafeCopy mv vec
  {-# INLINE basicUnsafeCopy #-}
  -- Sequence both components by chaining the component types' own 'elemseq'.
  -- The 'undefined' vectors act only as type proxies selecting which
  -- instance to use (assumes those implementations never force them).
  elemseq _ (Mass prob val) rest =
    G.elemseq (undefined :: U.Vector (Log Double)) prob
      (G.elemseq (undefined :: U.Vector a) val rest)
  {-# INLINE elemseq #-}

-- | Fully evaluates the payload. The probability bound sits in a strict
-- constructor field, so matching on 'Mass' has already forced it.
instance NFData a => NFData (Mass a) where
  rnf (Mass _ payload) = rnf payload
  {-# INLINE rnf #-}

-- | Runs the effectful function on the payload and re-attaches the original
-- probability bound to its result.
instance Traversable Mass where
  traverse g (Mass prob x) = fmap (Mass prob) (g x)
  {-# INLINE traverse #-}

-- | Application coincides with the 'Applicative' instance's '<*>'.
instance Apply Mass where
  mf <.> mx = mf <*> mx
  {-# INLINE (<.>) #-}

-- | Injects a value with probability bound 1, exactly as 'pure' does.
instance Pointed Mass where
  point = pure
  {-# INLINE point #-}

-- | Projects out the payload, discarding the probability bound.
instance Copointed Mass where
  copoint m = case m of
    Mass _ x -> x
  {-# INLINE copoint #-}

-- | 'pure' attaches a probability bound of 1. '<*>' applies the wrapped
-- function to the wrapped argument and multiplies the two bounds.
instance Applicative Mass where
  pure x = Mass 1 x
  {-# INLINE pure #-}
  mf <*> mx = case mf of
    Mass pf f -> case mx of
      Mass px x -> Mass (pf * px) (f x)
  {-# INLINE (<*>) #-}

-- | The unit is 'mempty' with a probability bound of 1. Combining two values
-- combines the payloads with 'mappend' and multiplies the bounds.
instance Monoid a => Monoid (Mass a) where
  mempty = pure mempty
  {-# INLINE mempty #-}
  mappend (Mass pl xl) (Mass pr xr) = Mass (pl * pr) (xl `mappend` xr)
  {-# INLINE mappend #-}

-- | Combines the payloads with '<>' and multiplies the probability bounds.
instance Semigroup a => Semigroup (Mass a) where
  (<>) (Mass pl xl) (Mass pr xr) = Mass (pl * pr) (xl <> xr)
  {-# INLINE (<>) #-}

-- | Feeds the payload to the continuation and multiplies the incoming bound
-- by the bound of the continuation's result.
instance Bind Mass where
  m >>- k = case m of
    Mass p a -> case k a of
      Mass q b -> Mass (p * q) b
  {-# INLINE (>>-) #-}

-- | Sequencing feeds the payload to the continuation and multiplies the
-- incoming probability bound by the bound of the continuation's result, so
-- each step can only lower (or keep) the overall bound.
instance Monad Mass where
  -- Canonical definition: defer to 'pure' from the 'Applicative' instance
  -- instead of restating it. This keeps the two in lockstep and avoids GHC's
  -- @-Wnoncanonical-monad-instances@ warning.
  return = pure
  {-# INLINE return #-}
  Mass p a >>= f = case f a of
    Mass q b -> Mass (p * q) b
  {-# INLINE (>>=) #-}

-- | Both operations keep the original probability bound on the outer layer:
-- 'duplicated' nests the whole value inside itself, and 'extended' replaces
-- the payload with the result of applying the function to the whole value.
instance Extend Mass where
  duplicated w@(Mass n _) = Mass n w
  {-# INLINE duplicated #-}
  extended k w = case w of
    Mass n _ -> Mass n (k w)
  {-# INLINE extended #-}

-- | 'extract' returns the payload and drops the bound. 'duplicate' and
-- 'extend' keep the original probability bound on the outer layer while
-- exposing the whole value (bound included) on the inside.
instance Comonad Mass where
  extract w = case w of
    Mass _ x -> x
  {-# INLINE extract #-}
  duplicate w@(Mass n _) = Mass n w
  {-# INLINE duplicate #-}
  extend k w = case w of
    Mass n _ -> Mass n (k w)
  {-# INLINE extend #-}

-- | Comonadic application coincides with the 'Applicative' instance's '<*>'.
instance ComonadApply Mass where
  wf <@> wa = wf <*> wa
  {-# INLINE (<@>) #-}

-- Operator fixities. '&?' and '|?' use the same associativity and precedence
-- as the Prelude's '&&' (infixr 3) and '||' (infixr 2), so '&?' binds tighter
-- than '|?'. '^?' is left-associative at precedence 6 (the fixity "Data.Bits"
-- gives to @xor@) and therefore binds tighter than both.
infixl 6 ^?
infixr 3 &?
infixr 2 |?

-- | Logical @and@ of two booleans that carry confidence lower bounds.
--
-- * Both 'False': either operand alone makes the result 'False', so the
--   larger of the two bounds is kept.
-- * Exactly one 'False': the result is 'False' with the bound of the 'False'
--   operand; the other bound is not used.
-- * Both 'True': the result is 'True' with the product of the two bounds.
(&?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p a &? Mass q b = case (a, b) of
  (False, False) -> Mass (max p q) False
  (False, True)  -> Mass p False
  (True,  False) -> Mass q False
  (True,  True)  -> Mass (p * q) True
{-# INLINE (&?) #-}

-- | Logical @or@ of two booleans that carry confidence lower bounds.
--
-- * Both 'False': the result is 'False' with the product of the two bounds.
-- * Exactly one 'True': the result is 'True' with the bound of the 'True'
--   operand; the other bound is not used.
-- * Both 'True': either operand alone makes the result 'True', so the larger
--   of the two bounds is kept.
(|?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p a |? Mass q b = case (a, b) of
  (False, False) -> Mass (p * q) False
  (False, True)  -> Mass q True
  (True,  False) -> Mass p True
  (True,  True)  -> Mass (max p q) True
{-# INLINE (|?) #-}

-- | Exclusive @or@ of two booleans that carry confidence lower bounds.
--
-- The result depends on both operands in every case, so the bounds are
-- always multiplied. The payload is 'True' exactly when the two inputs
-- differ.
(^?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p a ^? Mass q b = Mass (p * q) (a /= b)
{-# INLINE (^?) #-}