{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE MultiParamTypeClasses #-} -------------------------------------------------------------------- -- | -- Copyright : (c) Edward Kmett 2013 -- License : BSD3 -- Maintainer: Edward Kmett -- Stability : experimental -- Portability: non-portable -- -------------------------------------------------------------------- module Data.Approximate.Mass ( Mass(..) , (|?), (&?), (^?) ) where import Control.Applicative import Control.Comonad import Control.DeepSeq import Control.Monad import Data.Binary as Binary import Data.Bytes.Serial as Bytes import Data.Copointed import Data.Data import Data.Foldable import Data.Functor.Bind import Data.Functor.Extend import Data.Hashable import Data.Hashable.Extras import Data.Pointed import Data.SafeCopy import Data.Semigroup import Data.Serialize as Serialize import Data.Traversable import Data.Vector.Generic as G import Data.Vector.Generic.Mutable as M import Data.Vector.Unboxed as U import GHC.Generics import Numeric.Log -- | A quantity with a lower-bound on its probability mass. This represents -- a 'probable value' as a 'Monad' that you can use to calculate progressively -- less likely consequences. -- -- /NB:/ These probabilities are all stored in the log domain. This enables us -- to retain accuracy despite very long multiplication chains. We never add -- these probabilities so the additional overhead of working in the log domain -- is never incurred, except on transitioning in and out. -- -- This is most useful for discrete types, such as -- small 'Integral' instances or a 'Bounded' 'Enum' like -- 'Bool'. 
--
-- Also note that @('&?')@ and @('|?')@ are able to use knowledge about the
-- function to get better precision on their results than naively using
-- @'liftA2' ('&&')@ or @'liftA2' ('||')@.
data Mass a = Mass {-# UNPACK #-} !(Log Double) a
  deriving (Eq,Ord,Show,Read,Typeable,Data,Generic)

-- Serialization: every format below writes the mass first, then the payload,
-- and reads them back in the same order.

instance Binary a => Binary (Mass a) where
  put (Mass p a) = Binary.put p >> Binary.put a
  get = Mass <$> Binary.get <*> Binary.get

instance Serialize a => Serialize (Mass a) where
  put (Mass p a) = Serialize.put p >> Serialize.put a
  get = Mass <$> Serialize.get <*> Serialize.get

-- Empty instance: relies on safecopy's default methods.
-- NOTE(review): the 'Serialize' constraint suggests those defaults go through
-- cereal -- confirm against the safecopy version in use.
instance Serialize a => SafeCopy (Mass a)

-- Empty instances: both rely on their classes' default methods.
instance Hashable a => Hashable (Mass a)
instance Hashable1 Mass

instance Serial1 Mass where
  serializeWith f (Mass p a) = serialize p >> f a
  deserializeWith m = Mass <$> deserialize <*> m

instance Serial a => Serial (Mass a) where
  serialize (Mass p a) = serialize p >> serialize a
  deserialize = Mass <$> deserialize <*> deserialize

-- | Maps over the payload; the mass is left unchanged.
instance Functor Mass where
  fmap f (Mass p a) = Mass p (f a)
  {-# INLINE fmap #-}

-- | Folds over the single payload, ignoring the mass.
instance Foldable Mass where
  foldMap f (Mass _ a) = f a
  {-# INLINE foldMap #-}

-- Unboxed vectors of 'Mass' reuse the unboxed representation of
-- @('Log' 'Double', a)@ pairs: every method below converts between
-- @'Mass' p a@ and @(p, a)@ and delegates to the underlying pair vector.
newtype instance U.MVector s (Mass a) = MV_Mass (U.MVector s (Log Double,a))
newtype instance U.Vector (Mass a) = V_Mass (U.Vector (Log Double,a))

-- NOTE(review): newer vector releases (0.11+) add a 'basicInitialize' method to
-- this class; when targeting them it should be defined as
-- @basicInitialize (MV_Mass v) = M.basicInitialize v@. It is left out here so
-- the module keeps building against older vector -- confirm the supported range.
instance Unbox a => M.MVector U.MVector (Mass a) where
  basicLength (MV_Mass v) = M.basicLength v
  {-# INLINE basicLength #-}
  basicUnsafeSlice i n (MV_Mass v) = MV_Mass $ M.basicUnsafeSlice i n v
  {-# INLINE basicUnsafeSlice #-}
  basicOverlaps (MV_Mass v1) (MV_Mass v2) = M.basicOverlaps v1 v2
  {-# INLINE basicOverlaps #-}
  basicUnsafeNew n = MV_Mass `liftM` M.basicUnsafeNew n
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeReplicate n (Mass p a) = MV_Mass `liftM` M.basicUnsafeReplicate n (p,a)
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeRead (MV_Mass v) i = uncurry Mass `liftM` M.basicUnsafeRead v i
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeWrite (MV_Mass v) i (Mass p a) = M.basicUnsafeWrite v i (p,a)
  {-# INLINE basicUnsafeWrite #-}
  basicClear (MV_Mass v) = M.basicClear v
  {-# INLINE basicClear #-}
  basicSet (MV_Mass v) (Mass p a) = M.basicSet v (p,a)
  {-# INLINE basicSet #-}
  basicUnsafeCopy (MV_Mass v1) (MV_Mass v2) = M.basicUnsafeCopy v1 v2
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeMove (MV_Mass v1) (MV_Mass v2) = M.basicUnsafeMove v1 v2
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeGrow (MV_Mass v) n = MV_Mass `liftM` M.basicUnsafeGrow v n
  {-# INLINE basicUnsafeGrow #-}

instance Unbox a => G.Vector U.Vector (Mass a) where
  basicUnsafeFreeze (MV_Mass v) = V_Mass `liftM` G.basicUnsafeFreeze v
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeThaw (V_Mass v) = MV_Mass `liftM` G.basicUnsafeThaw v
  {-# INLINE basicUnsafeThaw #-}
  basicLength (V_Mass v) = G.basicLength v
  {-# INLINE basicLength #-}
  basicUnsafeSlice i n (V_Mass v) = V_Mass $ G.basicUnsafeSlice i n v
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeIndexM (V_Mass v) i = uncurry Mass `liftM` G.basicUnsafeIndexM v i
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeCopy (MV_Mass mv) (V_Mass v) = G.basicUnsafeCopy mv v
  {-# INLINE basicUnsafeCopy #-}
  elemseq _ (Mass p a) z = G.elemseq (undefined :: U.Vector (Log Double)) p
                         $ G.elemseq (undefined :: U.Vector a) a z
  {-# INLINE elemseq #-}

-- | Makes @'U.Vector' ('Mass' a)@ usable through the "Data.Vector.Unboxed" API.
-- The 'M.MVector' and 'G.Vector' instances above supply the superclasses.
instance Unbox a => Unbox (Mass a)

-- | The mass is a strict, unpacked field (already evaluated whenever the
-- constructor is), so only the payload needs to be forced.
instance NFData a => NFData (Mass a) where
  rnf (Mass _ a) = rnf a
  {-# INLINE rnf #-}

-- | Traverses the payload; the mass is carried through unchanged.
instance Traversable Mass where
  traverse f (Mass p a) = Mass p <$> f a
  {-# INLINE traverse #-}

instance Apply Mass where
  (<.>) = (<*>)
  {-# INLINE (<.>) #-}

-- | 'point' attaches full mass (1) to a value.
instance Pointed Mass where
  point = Mass 1
  {-# INLINE point #-}

-- | 'copoint' discards the mass and returns the payload.
instance Copointed Mass where
  copoint (Mass _ a) = a
  {-# INLINE copoint #-}

-- | 'pure' attaches full mass (1); '<*>' multiplies the masses of its operands.
instance Applicative Mass where
  pure = Mass 1
  {-# INLINE pure #-}
  Mass p f <*> Mass q a = Mass (p * q) (f a)
  {-# INLINE (<*>) #-}

-- | Combines payloads monoidally and multiplies masses.
--
-- 'mappend' is written out rather than defined as @('<>')@: the latter would
-- need @'Semigroup' a@, which @'Monoid' a@ only implies on base >= 4.11.
instance Monoid a => Monoid (Mass a) where
  mempty = Mass 1 mempty
  {-# INLINE mempty #-}
  mappend (Mass p a) (Mass q b) = Mass (p * q) (mappend a b)
  {-# INLINE mappend #-}

-- | Combines payloads with '<>' and multiplies masses.
instance Semigroup a => Semigroup (Mass a) where
  Mass p a <> Mass q b = Mass (p * q) (a <> b)
  {-# INLINE (<>) #-}

-- | Chains a computation, multiplying the mass of the input by the mass of the
-- result, so each step can only lower the confidence.
instance Bind Mass where
  Mass p a >>- f = case f a of
    Mass q b -> Mass (p * q) b
  {-# INLINE (>>-) #-}

-- | Same semantics as the 'Bind' instance; 'return' is 'pure'.
instance Monad Mass where
  return = pure
  {-# INLINE return #-}
  Mass p a >>= f = case f a of
    Mass q b -> Mass (p * q) b
  {-# INLINE (>>=) #-}

-- | Duplicating or extending copies the mass into the new layer.
instance Extend Mass where
  duplicated (Mass n a) = Mass n (Mass n a)
  {-# INLINE duplicated #-}
  extended f w@(Mass n _) = Mass n (f w)
  {-# INLINE extended #-}

instance Comonad Mass where
  extract (Mass _ a) = a
  {-# INLINE extract #-}
  duplicate (Mass n a) = Mass n (Mass n a)
  {-# INLINE duplicate #-}
  extend f w@(Mass n _) = Mass n (f w)
  {-# INLINE extend #-}

instance ComonadApply Mass where
  (<@>) = (<*>)
  {-# INLINE (<@>) #-}

infixl 6 ^?
infixr 3 &?
infixr 2 |?

-- | Calculate the logical @and@ of two booleans with confidence lower bounds.
--
-- A single 'False' operand already determines the result, so its mass is kept
-- on its own (the larger of the two when both are 'False'). Only a 'True'
-- result depends on both operands, and then the masses are multiplied.
(&?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p False &? Mass q False = Mass (max p q) False
Mass p False &? Mass _ True  = Mass p False
Mass _ True  &? Mass q False = Mass q False
Mass p True  &? Mass q True  = Mass (p * q) True
{-# INLINE (&?) #-}

-- | Calculate the logical @or@ of two booleans with confidence lower bounds.
--
-- A single 'True' operand already determines the result, so its mass is kept
-- on its own (the larger of the two when both are 'True'). Only a 'False'
-- result depends on both operands, and then the masses are multiplied.
(|?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p False |? Mass q False = Mass (p * q) False
Mass _ False |? Mass q True  = Mass q True
Mass p True  |? Mass _ False = Mass p True
Mass p True  |? Mass q True  = Mass (max p q) True
{-# INLINE (|?) #-}

-- | Calculate the exclusive @or@ of two booleans with confidence lower bounds.
--
-- The result always depends on both operands, so the masses are multiplied.
(^?) :: Mass Bool -> Mass Bool -> Mass Bool
Mass p a ^? Mass q b = Mass (p * q) (a /= b) -- '/=' on 'Bool' is exclusive or
{-# INLINE (^?) #-}