module Numeric.Log
  ( Log(..)
  , Precise(..)
  , sum
  ) where
import Prelude hiding (maximum, sum)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.Comonad
import Control.DeepSeq
import Control.Monad
import Data.Binary as Binary
import Data.Bytes.Serial
import Data.Complex
import Data.Data
import Data.Distributive
import Data.Foldable as Foldable hiding (sum)
import Data.Functor.Bind
import Data.Functor.Extend
import Data.Hashable
import Data.Hashable.Lifted
import Data.Int
import Data.List as List hiding (sum)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Semigroup
import Data.Semigroup.Foldable
import Data.Semigroup.Traversable
import Data.Serialize as Serialize
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable
#endif
import Data.Vector.Unboxed as U hiding (sum)
import Data.Vector.Generic as G hiding (sum)
import Data.Vector.Generic.Mutable as M
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import Text.Read as T
import Text.Show as T
-- | A number stored by its natural logarithm: @Exp x@ stands for the value
-- @exp x@, and 'ln' gives back @x@.
newtype Log a = Exp
  { ln :: a -- ^ the natural logarithm of the represented value
  }
  deriving (Eq, Ord, Data, Typeable, Generic)
-- | Shows the represented value (@exp@ of the stored logarithm), not the
-- logarithm itself.
instance (Floating a, Show a) => Show (Log a) where
  showsPrec prec = T.showsPrec prec . exp . ln
-- | Reads a plain value and stores its natural logarithm; the inverse of
-- the 'Show' instance.
instance (Floating a, Read a) => Read (Log a) where
  readPrec = fmap (Exp . log) (step T.readPrec)
-- | Binary encoding of the stored logarithm (not the represented value).
instance Binary a => Binary (Log a) where
  put (Exp a) = Binary.put a
  get = fmap Exp Binary.get
-- | cereal encoding of the stored logarithm (not the represented value).
instance Serialize a => Serialize (Log a) where
  put (Exp a) = Serialize.put a
  get = fmap Exp Serialize.get
-- | bytes encoding of the stored logarithm.
instance Serial a => Serial (Log a) where
  serialize (Exp a) = serialize a
  deserialize = fmap Exp deserialize
-- | Lifted bytes encoding: the supplied element codec handles the stored
-- logarithm directly.
instance Serial1 Log where
  serializeWith putA = putA . ln
  deserializeWith getA = fmap Exp getA
-- | Maps over the stored logarithm (no conversion to the linear domain).
instance Functor Log where
  fmap f = Exp . f . ln
-- | Hashes the stored logarithm.
instance Hashable a => Hashable (Log a) where
  hashWithSalt salt = hashWithSalt salt . ln
-- | Lifted hashing: the supplied element hasher sees the stored logarithm.
instance Hashable1 Log where
  liftHashWithSalt hashA salt = hashA salt . ln
-- | A 'Log' occupies exactly the memory of its stored logarithm: same size,
-- same alignment, and peeks/pokes go through the underlying type.
instance Storable a => Storable (Log a) where
  sizeOf (Exp a) = sizeOf a
  alignment (Exp a) = alignment a
  peek = fmap Exp . peek . castPtr
  poke p = poke (castPtr p) . ln
-- | Fully evaluates the stored logarithm.
instance NFData a => NFData (Log a) where
  rnf = rnf . ln
-- | A 'Log' holds exactly one element: its stored logarithm.
instance Foldable Log where
  foldMap f = f . ln
-- | Non-empty fold over the single stored logarithm.
instance Foldable1 Log where
  foldMap1 f = f . ln
-- | Traverses the single stored logarithm.
instance Traversable Log where
  traverse f = fmap Exp . f . ln
-- | Traverses the single stored logarithm with an 'Apply' effect.
instance Traversable1 Log where
  traverse1 f = fmap Exp . f . ln
-- | Pulls the 'Log' wrapper out of any functor by unwrapping each element.
instance Distributive Log where
  distribute wrapped = Exp (fmap ln wrapped)
-- | Applies the function to the whole 'Log' and re-wraps the result.
instance Extend Log where
  extended f = Exp . f
-- | Identity-like comonad: 'extract' returns the stored logarithm.
instance Comonad Log where
  extract = ln
  extend f = Exp . f
-- | Identity-like applicative over the stored logarithm.
instance Applicative Log where
  pure = Exp
  wf <*> wa = Exp (ln wf (ln wa))
-- | Applies the wrapped function to the wrapped argument.
instance ComonadApply Log where
  wf <@> wa = Exp (ln wf (ln wa))
-- | Applies the wrapped function to the wrapped argument.
instance Apply Log where
  wf <.> wa = Exp (ln wf (ln wa))
-- | Feeds the stored logarithm to the continuation.
instance Bind Log where
  m >>- f = f (ln m)
-- | Identity-like monad: binding feeds the stored logarithm onward.
instance Monad Log where
  return = pure
  m >>= f = f (ln m)
-- | Enumeration in the linear domain: each range is generated on the
-- represented values (@exp@ of the logarithms) and mapped back with @log@.
instance (RealFloat a, Precise a, Enum a) => Enum (Log a) where
  succ a = a + 1

  -- subtracts one from the represented value (log-domain subtraction)
  pred a = a - 1

  toEnum   = fromIntegral

  -- rounds the represented value, not the logarithm
  fromEnum = round . exp . ln

  enumFrom (Exp a) = [ Exp (log b) | b <- Prelude.enumFrom (exp a) ]

  enumFromThen (Exp a) (Exp b) = [ Exp (log c) | c <- Prelude.enumFromThen (exp a) (exp b) ]

  enumFromTo (Exp a) (Exp b) = [ Exp (log c) | c <- Prelude.enumFromTo (exp a) (exp b) ]

  enumFromThenTo (Exp a) (Exp b) (Exp c) = [ Exp (log d) | d <- Prelude.enumFromThenTo (exp a) (exp b) (exp c) ]
-- | Negative infinity. As a stored logarithm, @Exp negInf@ represents the
-- value zero (its @exp@ is 0).
negInf :: Fractional a => a
negInf = negate (1/0)
-- | Arithmetic on represented values, carried out on their logarithms.
--
-- Only non-negative values are representable. Operations whose result
-- would be negative ('negate' of a non-zero value, 'signum' of NaN)
-- produce a NaN logarithm.
instance (Precise a, RealFloat a) => Num (Log a) where
  -- multiplying values adds their logarithms
  Exp a * Exp b = Exp (a + b)

  -- log (exp a + exp b), factoring out the larger exponent so that the
  -- remaining correction term is log1pexp of a non-positive number
  Exp a + Exp b
    | a == b && isInfinite a && isInfinite b = Exp a
    | a >= b    = Exp (a + log1pexp (b - a))
    | otherwise = Exp (b + log1pexp (a - b))

  -- log (exp a - exp b) = a + log (1 - exp (b - a)); zero minus zero is
  -- special-cased because (-inf) - (-inf) would be NaN
  Exp a - Exp b
    | isInfinite a && isInfinite b && a < 0 && b < 0 = Exp negInf
    | otherwise = Exp (a + log1mexp (b - a))

  signum a
    | a == 0    = Exp negInf -- 0
    | a > 0     = Exp 0      -- 1
    | otherwise = Exp (0/0)  -- NaN

  -- only zero has a representable negation
  negate (Exp a)
    | isInfinite a && a < 0 = Exp negInf
    | otherwise             = Exp (0/0)

  -- every representable value is already non-negative
  abs = id

  fromInteger = Exp . log . fromInteger
-- | Division of represented values, carried out on their logarithms.
instance (Precise a, RealFloat a) => Fractional (Log a) where
  -- dividing values subtracts their logarithms
  Exp a / Exp b = Exp (a - b)

  fromRational = Exp . log . fromRational
-- | Splits the represented value into integral and fractional parts; the
-- fractional part is returned as a 'Log'.
instance (Precise a, RealFloat a) => RealFrac (Log a) where
  properFraction l@(Exp x)
    -- a negative logarithm means the value is below one: no integral part
    | x < 0     = (0, l)
    | otherwise = case properFraction (exp x) of
        (whole, frac) -> (whole, Exp (log frac))
-- Unboxed vectors of 'Log' values are stored as unboxed vectors of the
-- underlying logarithms.
newtype instance U.MVector s (Log a) = MV_Log (U.MVector s a)
newtype instance U.Vector (Log a) = V_Log (U.Vector a)

instance (RealFloat a, Unbox a) => Unbox (Log a)
-- | Mutable unboxed 'Log' vectors: every method unwraps the 'MV_Log' and
-- 'Exp' newtypes and delegates to the underlying element type.
instance Unbox a => M.MVector U.MVector (Log a) where
  basicLength (MV_Log mv) = M.basicLength mv
  basicUnsafeSlice start len (MV_Log mv) = MV_Log (M.basicUnsafeSlice start len mv)
  basicOverlaps (MV_Log xs) (MV_Log ys) = M.basicOverlaps xs ys
  basicUnsafeNew len = liftM MV_Log (M.basicUnsafeNew len)
  basicUnsafeReplicate len x = liftM MV_Log (M.basicUnsafeReplicate len (ln x))
  basicUnsafeRead (MV_Log mv) ix = liftM Exp (M.basicUnsafeRead mv ix)
  basicUnsafeWrite (MV_Log mv) ix x = M.basicUnsafeWrite mv ix (ln x)
  basicClear (MV_Log mv) = M.basicClear mv
#if MIN_VERSION_vector(0,11,0)
  basicInitialize (MV_Log mv) = M.basicInitialize mv
#endif
  basicSet (MV_Log mv) x = M.basicSet mv (ln x)
  basicUnsafeCopy (MV_Log dst) (MV_Log src) = M.basicUnsafeCopy dst src
  basicUnsafeGrow (MV_Log mv) extra = liftM MV_Log (M.basicUnsafeGrow mv extra)
-- | Immutable unboxed 'Log' vectors: every method unwraps the 'V_Log' /
-- 'MV_Log' and 'Exp' newtypes and delegates to the underlying element type.
instance (RealFloat a, Unbox a) => G.Vector U.Vector (Log a) where
  basicUnsafeFreeze (MV_Log mv) = liftM V_Log (G.basicUnsafeFreeze mv)
  basicUnsafeThaw (V_Log v) = liftM MV_Log (G.basicUnsafeThaw v)
  basicLength (V_Log v) = G.basicLength v
  basicUnsafeSlice start len (V_Log v) = V_Log (G.basicUnsafeSlice start len v)
  basicUnsafeIndexM (V_Log v) ix = liftM Exp (G.basicUnsafeIndexM v ix)
  basicUnsafeCopy (MV_Log dst) (V_Log src) = G.basicUnsafeCopy dst src
  elemseq _ (Exp x) z = G.elemseq (undefined :: U.Vector a) x z
-- | Converts the represented value (@exp@ of the logarithm) to a 'Rational'.
instance (Precise a, RealFloat a, Ord a) => Real (Log a) where
  toRational = toRational . exp . ln
-- | Strict accumulator for 'sconcat': an element counter paired with the
-- largest logarithm seen so far.
data Acc1 a = Acc1 !Int64 !a
-- | Values combine by addition. 'sconcat' computes a numerically stable
-- log-sum-exp: it finds the largest logarithm @a@, then returns
-- @a + log (sum (exp (x - a)))@ over every element.
instance (Precise a, RealFloat a) => Semigroup (Log a) where
  (<>) = (+)

  sconcat (Exp z :| zs) = Exp $ case List.foldl' step1 (Acc1 0 z) zs of
    Acc1 nm1 a
      -- an infinite maximum dominates (or every element is zero)
      | isInfinite a -> a
      -- sum over all n elements of expm1 (x - a), plus (n - 1), equals
      -- sum (exp (x - a)) - 1. The head z must seed the fold: the counter
      -- nm1 only counts the tail, and z may not be the maximum.
      | otherwise    -> a + log1p (List.foldl' (step2 a) (expm1 (z - a)) zs + fromIntegral nm1)
    where
      -- count the tail and track the largest logarithm
      step1 (Acc1 n y) (Exp x) = Acc1 (n + 1) (max x y)
      -- accumulate exp (x - a) - 1 for each element
      step2 a r (Exp x) = r + expm1 (x - a)
-- | The additive monoid: the identity is zero (a logarithm of negative
-- infinity), and 'mconcat' defers non-empty lists to 'sconcat'.
instance (Precise a, RealFloat a) => Monoid (Log a) where
  mempty = Exp negInf
#if !(MIN_VERSION_base(4,11,0))
  mappend = (<>)
#endif
  mconcat xs = case xs of
    []     -> 0
    y : ys -> sconcat (y :| ys)
-- | Applies a function to the represented value. The logarithm is
-- exponentiated, @f@ is applied, and the result's logarithm is stored.
logMap :: Floating a => (a -> a) -> Log a -> Log a
logMap f (Exp a) = Exp (log (f (exp a)))
-- | Strict accumulator for 'sum'. 'None' means no element has been seen.
-- Otherwise it holds a counter and the largest logarithm so far.
data Acc a = Acc !Int64 !a | None
-- | Numerically stable sum of log-domain values (log-sum-exp).
--
-- The first pass finds the largest logarithm @a@ and counts the elements.
-- The second pass accumulates @exp (x - a) - 1@ via 'expm1', and the result
-- is @a + log1p (that sum + (n - 1))@. An empty container sums to zero
-- (@Exp negInf@).
sum :: (RealFloat a, Precise a, Foldable f) => f (Log a) -> Log a
sum xs = Exp $ case Foldable.foldl' step1 None xs of
  None -> negInf
  Acc nm1 a
    -- an infinite maximum dominates (or every element is zero)
    | isInfinite a -> a
    | otherwise    -> a + log1p (Foldable.foldl' (step2 a) 0 xs + fromIntegral nm1)
  where
    -- nm1 ends up as (element count - 1)
    step1 None      (Exp x) = Acc 0 x
    step1 (Acc n y) (Exp x) = Acc (n + 1) (max x y)
    step2 a r (Exp x) = r + expm1 (x - a)
-- | Elementary functions of represented values.
--
-- 'pi', 'exp', 'log', '**', 'sqrt' and 'logBase' have direct formulas in
-- terms of the stored logarithm. The trigonometric and hyperbolic functions
-- round-trip through the linear domain via 'logMap'.
instance (RealFloat a, Precise a) => Floating (Log a) where
  pi = Exp (log pi)

  -- the logarithm of exp x is x itself, i.e. exp a
  exp (Exp a) = Exp (exp a)

  -- the logarithm of log x is log a
  log (Exp a) = Exp (log a)

  -- x ** y = exp (ln x * y), with y = exp e
  Exp b ** Exp e = Exp (b * exp e)

  -- halving the logarithm takes the square root
  sqrt (Exp a) = Exp (a / 2)

  -- logBase x y = log y / log x = b / a. Dividing the stored logarithms
  -- directly avoids exponentiating them, which overflows or underflows for
  -- large-magnitude exponents (e.g. exp 1000 :: Double is Infinity).
  logBase (Exp a) (Exp b) = Exp (log (b / a))

  sin = logMap sin

  cos = logMap cos

  tan = logMap tan

  asin = logMap asin

  acos = logMap acos

  atan = logMap atan

  sinh = logMap sinh

  cosh = logMap cosh

  tanh = logMap tanh

  asinh = logMap asinh

  acosh = logMap acosh

  atanh = logMap atanh
-- | Floating-point types that supply the @log1p@/@expm1@ family used by
-- the log-domain arithmetic on 'Log'.
--
-- Only 'log1p' and 'expm1' are required. 'log1pexp' and 'log1mexp' default
-- to direct compositions, and instances may override them with more careful
-- versions.
class Floating a => Precise a where
  -- | @log1p x@ computes @log (1 + x)@.
  log1p :: a -> a

  -- | @expm1 x@ computes @exp x - 1@.
  expm1 :: a -> a

  -- | @log1pexp x@ computes @log (1 + exp x)@.
  log1pexp :: a -> a
  log1pexp = log1p . exp

  -- | @log1mexp x@ computes @log (1 - exp x)@.
  log1mexp :: a -> a
  log1mexp = log1p . negate . exp
-- | 'Double' instance backed by the @c_log1p@ / @c_expm1@ primitives.
instance Precise Double where
  log1p = c_log1p

  expm1 = c_expm1

  -- log (1 - exp a) for a <= 0, split at -log 2 as in Maechler (2012).
  -- Near zero (a >= -log 2), 1 - exp a cancels, so use -expm1 a. Far from
  -- zero, exp a is small, so use log1p (-exp a). The previous test
  -- `a <= log 2` was true for every a <= 0 and always took the first
  -- branch, which rounds to 0 for very negative a.
  log1mexp a
    | a >= negate (log 2) = log (negate (expm1 a))
    | otherwise           = log1p (negate (exp a))

  -- log (1 + exp a); for large a, exp (-a) is a tiny correction and
  -- eventually negligible
  log1pexp a
    | a <= 18   = log1p (exp a)
    | a <= 100  = a + exp (negate a)
    | otherwise = a
-- | 'Float' instance backed by the @c_log1pf@ / @c_expm1f@ primitives.
instance Precise Float where
  log1p = c_log1pf

  expm1 = c_expm1f

  -- log (1 - exp a) for a <= 0, split at -log 2 (Maechler 2012). Near
  -- zero, use -expm1 a to avoid cancellation. Far from zero, use
  -- log1p (-exp a).
  log1mexp a
    | a >= negate (log 2) = log (negate (expm1 a))
    | otherwise           = log1p (negate (exp a))

  -- log (1 + exp a); for large a, exp (-a) is a tiny correction and
  -- eventually negligible
  log1pexp a
    | a <= 18   = log1p (exp a)
    | a <= 100  = a + exp (negate a)
    | otherwise = a
-- | Complex 'expm1' and 'log1p'. Near the origin they use cancellation-free
-- formulas; elsewhere they fall back to the direct definitions.
instance (RealFloat a, Precise a) => Precise (Complex a) where
  -- exp (a + ib) - 1 = (e^a cos b - 1) + i e^a sin b. With u = e^a - 1 and
  -- w = cos b - 1 = -2 sin^2 (b/2), the real part is u*w + u + w.
  expm1 x@(a :+ b)
    | a*a + b*b < 1, u <- expm1 a, v <- sin (b/2), w <- negate (2*v*v) = (u*w+u+w) :+ (u+1)*sin b
    | otherwise = exp x - 1

  -- log (1 + x) = log |1 + x| + i arg (1 + x). With u = |1 + x|^2 - 1, the
  -- real part is log1p (sqrt (1 + u) - 1), and u / (1 + sqrt (u + 1)) is a
  -- cancellation-free form of sqrt (1 + u) - 1. The argument is
  -- atan2 (imag) (real).
  log1p x@(a :+ b)
    | abs a < 0.5 && abs b < 0.5, u <- 2*a+a*a+b*b = log1p (u/(1+sqrt (u+1))) :+ atan2 b (1 + a)
    | otherwise = log (1 + x)
#ifdef __USE_FFI__
foreign import ccall unsafe "math.h log1p" c_log1p :: Double -> Double
foreign import ccall unsafe "math.h expm1" c_expm1 :: Double -> Double
foreign import ccall unsafe "math.h expm1f" c_expm1f :: Float -> Float
foreign import ccall unsafe "math.h log1pf" c_log1pf :: Float -> Float
#else
-- Pure-Haskell fallbacks used when the C math library is not linked.
c_log1p :: Double -> Double
c_log1p = log1pFallback

c_expm1 :: Double -> Double
c_expm1 = expm1Fallback

c_expm1f :: Float -> Float
c_expm1f = expm1Fallback

c_log1pf :: Float -> Float
c_log1pf = log1pFallback

-- | @log (1 + x)@ without losing the low-order bits of small @x@.
-- If @1 + x@ rounds to 1, the answer is @x@. Otherwise the rounding error
-- in @u = 1 + x@ is corrected by the factor @x / (u - 1)@ (Goldberg's
-- trick). Infinite @u@ goes straight to 'log' to avoid @inf / inf@.
log1pFallback :: RealFloat a => a -> a
log1pFallback x
  | u == 1       = x
  | isInfinite u = log u
  | otherwise    = log u * x / (u - 1)
  where
    u = 1 + x

-- | @exp x - 1@ without cancellation for small @x@ (Kahan's trick):
-- @(u - 1) * x / log u@ with @u = exp x@. The exact cases are handled
-- first: @u == 1@ gives @x@, @u - 1 == -1@ gives -1, and infinite @u@ gives @u@.
expm1Fallback :: RealFloat a => a -> a
expm1Fallback x
  | u == 1       = x
  | um1 == -1    = -1
  | isInfinite u = u
  | otherwise    = um1 * x / log u
  where
    u   = exp x
    um1 = u - 1
#endif