-- | This module provides log-domain functionality. Ed Kmett provides, with
-- @log-domain@, a generic way to handle numbers in the log-domain, some of
-- which is used under the hood here. We want some additional type safety and
-- also want to connect with the 'SemiRing' module.

module Numeric.LogDomain where

import Control.Monad.Except
import Numeric.Log as NL
import qualified Data.Vector.Fusion.Stream.Monadic as SM
import qualified Data.Vector.Fusion.Util as SM
import Debug.Trace
import Numeric



-- | Instances for @LogDomain x@ should be for specific types.

class LogDomain x where
  -- | The type family to connect a type @x@ with the type @Ln x@ in the
  -- log-domain.
  type Ln x :: *
  -- | Transport a value in @x@ into the log-domain. @logdom@ should throw an
  -- exception (a 'String' message via 'throwError') if @log x@ is not valid.
  logdom :: (MonadError String m) => x -> m (Ln x)
  -- | Unsafely transport @x@ into the log-domain. Unlike 'logdom' there is no
  -- way to report failure: the result type is a plain @Ln x@.
  unsafelogdom :: x -> Ln x
  -- | Transport a value @Ln x@ back into the linear domain @x@.
  lindom :: Ln x -> x



-- | 'Double' values are transported into Kmett's @'Log' 'Double'@.
instance LogDomain Double where
  type Ln Double = Log Double
  -- | Checked transport. Negative inputs and @NaN@ are rejected, since neither
  -- has a valid logarithm. Note that @NaN < 0@ is 'False', so @NaN@ needs its
  -- own guard. Zero is accepted and maps to @Exp (-Infinity)@.
  {-# Inline logdom #-}
  logdom x
    | isNaN x   = throwError "log of NaN"
    | x < 0     = throwError "log of negative number"
    | otherwise = return $ unsafelogdom x
  -- | Unchecked transport: stores @log x@ as the log-domain payload.
  {-# Inline unsafelogdom #-}
  unsafelogdom = Exp . log
  -- | Back to the linear domain: exponentiate the stored logarithm.
  {-# Inline lindom #-}
  lindom = exp . ln



-- | This is similar to 'Numeric.Log.sum' but requires only one pass over the
-- data. It will be useful if the first two elements in the stream are large,
-- because the pivot @m@ that all later elements are shifted by is fixed once
-- the first finite element has been seen.
-- If the user has some control over how the stream is generated, this function
-- might show better performance than 'Numeric.Log.sum' and better numeric
-- stability than 'fold 0 (+)'
--
-- The first argument is the value to start from (and the result for an empty
-- stream); the second is the stream of log-domain values to add up.
--
-- Infinite elements are handled explicitly: leading @-Infinity@ values are
-- skipped, and once a @+Infinity@ has been seen the result stays @+Infinity@.
--
-- TODO this needs to be benchmarked against @fold 0 (+)@, since in
-- @DnaProteinAlignment@ @sumS@ seems to be slower!

sumS
  :: (Monad m, Ord a, RealFloat a, Show a)
  => Log a -> SM.Stream m (Log a)
  -> m (Log a)
{-# Inline sumS #-}
sumS zero (SM.Stream step s0) = sLoop1 SM.SPEC zero s0
  where
    -- we need to find the first @x@ that is not @(-1/0)@ to handle @x-m@
    -- correctly. We loop @sLoop1@ until we have the first finite @y@ and use
    -- that as the @m@ for @sLoop2@.
    sLoop1 SM.SPEC (Exp x) s = step s >>= \case
      SM.Done    -> return $ Exp x
      SM.Skip s1 -> sLoop1 SM.SPEC (Exp x) s1
      SM.Yield (Exp y) s2
        -- Stay in this loop while the pivot would be infinite. @isInfinite y@
        -- covers an infinite new element; @isInfinite m@ additionally covers a
        -- finite @y@ arriving after the running value became @+Infinity@,
        -- where @x - m@ would otherwise be @Infinity - Infinity = NaN@.
        | isInfinite y || isInfinite m -> sLoop1 SM.SPEC (Exp m) s2
        -- Two values (@x@ and @y@) enter the accumulator here, so the counter
        -- starts at one: it holds "number of summed elements minus one".
        | otherwise -> sLoop2 SM.SPEC m (1 :: Int) (expm1 (x - m) + expm1 (y - m)) s2
        where m = max x y
    -- from here on we are fine: @m@ is finite and fixed. @acc@ collects
    -- @sum (expm1 (x_i - m))@, so that
    -- @log (sum (exp x_i)) = m + log1p (acc + (n - 1))@ with @cnt = n - 1@.
    sLoop2 SM.SPEC !m !cnt !acc s = step s >>= \case
      SM.Done             -> return $ Exp $ m + log1p (acc + fromIntegral cnt)
      SM.Skip s2          -> sLoop2 SM.SPEC m cnt acc s2
      SM.Yield (Exp x) s2 -> sLoop2 SM.SPEC m (cnt + 1) (acc + expm1 (x - m)) s2

-- | @log-sum-exp@ for streams, without examining the stream twice, but with
-- the potential for numeric problems. In principle, the numeric error of this
-- function should be better than individual binary function application and
-- worse than an optimized @sum@ function.
--
-- The pivot that all later elements are shifted by is the maximum of the
-- first two elements whose maximum is finite; it is not updated afterwards.
--
-- An empty stream yields @-Infinity@ (the logarithm of an empty sum), and a
-- single-element stream yields that element unchanged.
--
-- Needs to be written in direct style, as otherwise any constructors (to tell
-- us if we collected two elements already) remain.

logsumexpS
  :: (Monad m, Ord a, Num a, Floating a)
  => SM.Stream m a -> m a
{-# Inline logsumexpS #-}
logsumexpS (SM.Stream step s0) = lseLoop0 SM.SPEC s0
  where
    -- No element seen yet.
    lseLoop0 SM.SPEC s = step s >>= \case
      -- log of the empty sum, i.e. @log 0@; not @0@, which would be @log 1@.
      SM.Done       -> return (-(1 / 0))
      SM.Skip s0'   -> lseLoop0 SM.SPEC s0'
      SM.Yield x s1 -> lseLoop1 SM.SPEC x s1
    -- Exactly one candidate value @x@ is held; look for a second element so
    -- that a pivot can be fixed.
    lseLoop1 SM.SPEC x s = step s >>= \case
      SM.Done     -> return x
      SM.Skip s1' -> lseLoop1 SM.SPEC x s1'
      SM.Yield y sA
        -- @m - m == 0@ is a finiteness test that needs only 'Num' and 'Eq':
        -- for IEEE values it fails exactly for @+-Infinity@ and @NaN@. With a
        -- finite pivot neither @x - m@ nor @y - m@ can become @NaN@.
        | m - m == 0 -> lseLoopAcc SM.SPEC m (exp (x - m) + exp (y - m)) sA
        -- Non-finite maximum (e.g. both elements are @-Infinity@): shifting by
        -- it would produce @Infinity - Infinity@. Keep the maximum as the
        -- single candidate and continue looking for a finite pivot.
        | otherwise  -> lseLoop1 SM.SPEC m sA
        where m = max x y
    -- Pivot @m@ is fixed and finite; @acc@ holds @sum (exp (x_i - m))@.
    lseLoopAcc SM.SPEC !m !acc s = step s >>= \case
      SM.Done        -> return $ m + log acc
      SM.Skip sA'    -> lseLoopAcc SM.SPEC m acc sA'
      SM.Yield z sA' -> lseLoopAcc SM.SPEC m (acc + exp (z - m)) sA'