{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Biobase.Types.Partition where

import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import Data.Primitive.Types

import Biobase.Types.Ring



-- | A probability represented by its natural logarithm: 'mkPartition' takes
-- the 'log' of a probability and 'unPartition' exponentiates it back. The
-- record field @unPartition'@ gives direct access to the stored log value.
--
-- Because 'log' is monotone, the derived 'Eq' and 'Ord' on the log value agree
-- with comparing the underlying probabilities. A 'Num' instance is
-- intentionally not derived: arithmetic is meant to go through the 'Ring'
-- instance, and anyone who wants 'Num' has to instantiate it explicitly.
newtype Partition = Partition { unPartition' :: Double }
  deriving (Eq, Ord, Read, Show)

-- | Smart constructor: turn a probability in @[0, 1]@ into a 'Partition' by
-- storing its natural logarithm. @mkPartition 0@ stores negative infinity and
-- @mkPartition 1@ stores @0@.
--
-- Calls 'error' if the argument is NaN, negative, or greater than one.
mkPartition :: Double -> Partition
mkPartition x
  -- NaN fails every ordered comparison, so without this guard it would slip
  -- past the range checks below and yield a NaN log value.
  | isNaN x   = error "mkPartition: prob is NaN"
  | x < 0     = error $ "mkPartition: prob <0: " ++ show x
  | x > 1     = error $ "mkPartition: prob >1: " ++ show x
  | otherwise = Partition $ log x

-- | Recover the probability held by a 'Partition' by exponentiating the stored
-- log value. This undoes 'mkPartition' up to floating-point rounding.
unPartition :: Partition -> Double
unPartition = exp . unPartition'


-- | Ring operations on 'Partition' values. The wrapped 'Double' is a natural
-- logarithm (see 'mkPartition'), so every operation works in log space:
--
-- * addition is a log-sum-exp ('logSum');
-- * multiplication adds the logs;
-- * integral powers scale the log;
-- * 'zero' is @log 0 = -Infinity@ and 'one' is @log 1 = 0@.
--
-- Negation and '.^^.' are not supported and call 'error'.
instance Ring Partition where
  (Partition a) .+. (Partition b) = Partition $ logSum a b
  {-# INLINE (.+.) #-}
  (Partition a) .*. (Partition b) = Partition $ a + b
  {-# INLINE (.*.) #-}
  -- p^0 is 1 even for p = 0; without the special case the log value
  -- (-Infinity) * 0 would evaluate to NaN.
  (Partition a) .^. k
    | k == 0    = Partition 0
    | otherwise = Partition $ a * fromIntegral k
  {-# INLINE (.^.) #-}
  _ .^^. _ = error ".^^. not defined for Partition"
  {-# INLINE (.^^.) #-}
  -- A log-space value can only encode non-negative probabilities, so there is
  -- no additive inverse.
  neg (Partition a) = error $ "negate partition? " ++ show a
  {-# INLINE neg #-}
  one = Partition 0
  {-# INLINE one #-}
  zero = Partition (negate (1 / 0))
  {-# INLINE zero #-}
  -- Log-space zero is exactly -Infinity ('mkPartition 0' produces it via
  -- 'log 0'); tiny but non-zero probabilities have finite logs, so no epsilon
  -- is needed here.
  isZero (Partition a) = isInfinite a && a < 0
  {-# INLINE isZero #-}

-- | Add two values in log space: @logSum a b = log (exp a + exp b)@.
--
-- The sum is evaluated as @x + log1p (exp (y - x))@, where @x@ is the larger
-- argument, so the exponential's argument is never positive and cannot
-- overflow. When both arguments are the same infinity (for example two
-- log-zero probabilities), that infinity is returned instead of the NaN that
-- @y - x@ would produce. NaN inputs still propagate to a NaN result.
logSum :: Double -> Double -> Double
logSum a b
  | a > b     = f a b
  | otherwise = f b a
  where
    -- x is the larger (or equal) argument.
    f x y
      | isInfinite x && x == y = x
      | otherwise              = x + log1p (exp $ y - x)
    {-# INLINE f #-}
{-# INLINE logSum #-}



-- * Vector (and Prim) instances.

-- Each instance below is newtype-derived (StandaloneDeriving together with
-- GeneralizedNewtypeDeriving, both enabled at the top of the file), reusing
-- the corresponding instance of the wrapped 'Double'.
--
-- NOTE(review): role-aware GHC (7.8 and later) is expected to reject
-- newtype-deriving 'VGM.MVector' / 'VG.Vector' here, because 'VU.MVector' and
-- 'VU.Vector' are data families whose parameters are nominal. Confirm the
-- target compiler. On modern GHC these likely need explicit
-- @newtype instance@ declarations (TypeFamilies) or vector's own
-- newtype-unboxing helpers.
deriving instance Prim Partition
deriving instance VGM.MVector VU.MVector Partition
deriving instance VG.Vector VU.Vector Partition
deriving instance VU.Unbox Partition



-- * C math.h bindings for log1p and expm1

-- | @log1p x@ computes @log (1 + x)@ using the C library's @log1p@. It is
-- imported through the FFI because the original author found the C version
-- much faster than the Haskell route (not re-measured here).
foreign import ccall unsafe "math.h log1p"
  log1p :: Double -> Double

-- | @expm1 x@ computes @exp x - 1@ using the C library's @expm1@, imported for
-- the same reason as 'log1p'.
foreign import ccall unsafe "math.h expm1"
  expm1 :: Double -> Double