{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}

#ifdef AD_FFI
{-# LANGUAGE ForeignFunctionInterface #-}
#endif

module Numeric.AD.Internal.Reverse.Double
  ( ReverseDouble(..)
  , Tape(..)
  , reifyTape
  , reifyTypeableTape
  , partials
  , partialArrayOf
  , partialMapOf
  , derivativeOf
  , derivativeOf'
  , bind
  , unbind
  , unbindMap
  , unbindWith
  , unbindMapWithDefault
  , var
  , varId
  , primal
  ) where

#ifdef AD_FFI
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
import qualified Foreign.Marshal.Array as MA
import qualified Foreign.Marshal.Alloc as MA
#else
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Unsafe.Coerce
#endif

import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import Data.Array
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)

#ifdef AD_FFI

-- | Handle to a tape whose storage lives on the C side. The wrapped
-- 'ForeignPtr' is what 'newTape' attaches the @tape_free@ finalizer to.
newtype Tape = Tape
  { getTape :: ForeignPtr Tape
  }

-- Raw bindings to the C tape implementation (the C code is not part of this
-- module). Variable ids cross the boundary as 'CInt'.
foreign import ccall unsafe "tape_alloc" c_tape_alloc :: CInt -> CInt -> IO (Ptr Tape)
-- The id returned by @tape_push@ is read back as 'CInt', matching the 'CInt'
-- ids it is given. Reading it as 'Int' would take a full machine word, which
-- picks up garbage high bits if the C function returns a 32-bit @int@.
-- NOTE(review): assumes the C prototype returns @int@ — confirm in cbits.
foreign import ccall unsafe "tape_push" c_tape_push :: Ptr Tape -> CInt -> CInt -> Double -> Double -> IO CInt
foreign import ccall unsafe "tape_backPropagate" c_tape_backPropagate :: Ptr Tape -> CInt -> Ptr Double -> IO ()
foreign import ccall unsafe "tape_variables" c_tape_variables :: Ptr Tape -> IO CInt
foreign import ccall unsafe "&tape_free" c_ref_tape_free :: FinalizerPtr Tape

-- | Append an entry to the tape reflected from @p@.
--
-- @i1@ and @i2@ are the ids of the entry's inputs and @d1@ and @d2@ are the
-- corresponding partial derivatives. The result is the id that @tape_push@
-- reports for the new entry, converted back to 'Int'.
pushTape :: Reifies s Tape => p s -> Int -> Int -> Double -> Double -> IO Int
pushTape p i1 i2 d1 d2 =
  withForeignPtr (getTape (reflect p)) $ \tape ->
    fromIntegral <$> c_tape_push tape (fromIntegral i1) (fromIntegral i2) d1 d2
{-# INLINE pushTape #-}

-- | Extract the partials from the current chain for a given AD variable.
--
-- Constants ('Zero' and 'Lift') have no tape entry and yield @[]@. For a
-- tape-backed value with id @k@, a scratch buffer is sized by what
-- @tape_variables@ reports. @tape_backPropagate@ fills it starting from @k@,
-- and its contents are returned as a list.
partials :: forall s. (Reifies s Tape) => ReverseDouble s -> [Double]
partials Zero        = []
partials (Lift _)    = []
partials (ReverseDouble k _) = unsafePerformIO $
  withForeignPtr (getTape (reflect (Proxy :: Proxy s))) $ \tape -> do
    l <- fromIntegral <$> c_tape_variables tape
    -- allocaArray releases the scratch buffer even if an exception (including
    -- an asynchronous one) arrives while it is in use. A manual mallocArray/free
    -- pair would leak it. peekArray reads every element inside the IO action,
    -- so the returned list never refers to the freed buffer.
    MA.allocaArray l $ \arr -> do
      c_tape_backPropagate tape (fromIntegral k) arr
      MA.peekArray l arr
{-# INLINE partials #-}

-- | Allocate a C-side tape for @vs@ input variables and attach the
-- @tape_free@ finalizer, so it is released once the 'Tape' is unreachable.
--
-- The second argument to @tape_alloc@ (@4 * 1024@) is passed through as-is.
-- NOTE(review): presumably an initial capacity / chunk size — confirm in cbits.
--
-- Throws an 'IOError' if @tape_alloc@ returns NULL.
newTape :: Int -> IO Tape
newTape vs = do
  p <- c_tape_alloc (fromIntegral vs) (4 * 1024)
  -- Do not wrap a NULL tape: it would later be handed to tape_push and to the
  -- tape_free finalizer, which presumably dereference it.
  when (p == nullPtr) $
    ioError (userError "Numeric.AD.Internal.Reverse.Double.newTape: tape_alloc returned NULL")
  Tape <$> newForeignPtr c_ref_tape_free p

-- | Construct a tape that starts with @vs@ variables, and run @k@ with that
-- tape reflected into the type parameter @s@.
--
-- NOINLINE keeps the 'unsafePerformIO' allocation from being duplicated or
-- shared through inlining.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reify tape k)
{-# NOINLINE reifyTape #-}

-- | Construct a tape that starts with @vs@ variables, and run @k@ with that
-- tape reflected into a 'Typeable' type parameter @s@.
--
-- NOINLINE keeps the 'unsafePerformIO' allocation from being duplicated or
-- shared through inlining.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reifyTypeable tape k)
{-# NOINLINE reifyTypeableTape #-}

-- | Create a new entry on the chain for a unary function.
--
-- Takes the function @f@, its derivative @di@ with respect to its input, the
-- variable id @i@ of that input, and the input's value @b@. The result value
-- @f b@ is forced before the entry is pushed. Used by 'unary' and 'binary'
-- internally.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble entry $! f b
  where
    -- The second input slot of the pushed entry is filled with id 0 and
    -- weight 0.0.
    entry = unsafePerformIO (pushTape (Proxy :: Proxy s) i 0 di 0.0)
{-# INLINE unarily #-}

-- | Create a new entry on the chain for a binary function.
--
-- Takes the function @f@, its derivatives @di@ and @dj@ with respect to its
-- two inputs, and the inputs' variable ids (@i@, @j@) and values (@b@, @c@).
-- The result value @f b c@ is forced before the entry is pushed. Used by
-- 'binary' internally.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c = ReverseDouble entry $! f b c
  where
    entry = unsafePerformIO (pushTape (Proxy :: Proxy s) i j di dj)
{-# INLINE binarily #-}

#else

-- | The recorded entries of a tape, newest first.
--
-- * 'Nil' — the bottom of the tape.
-- * 'Unary' — an entry with one input: its index and the partial derivative
--   with respect to it.
-- * 'Binary' — an entry with two inputs: both indices, then both partials.
data Cells
  = Nil
  | Unary {-# UNPACK #-} !Int {-# UNPACK #-} !Double !Cells
  | Binary {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Double {-# UNPACK #-} !Double !Cells

-- | Drop the @n@ newest cells from a chain, stopping early at 'Nil'.
-- The decremented count is forced at each step so no thunk chain builds up.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs                  = xs
dropCells _ Nil                 = Nil
dropCells n (Unary _ _ xs)      = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs

-- | Current state of a tape: the index handed out most recently (initially
-- the number of input variables) and the recorded cells, newest first.
data Head = Head
  {-# UNPACK #-} !Int
  !Cells

-- | Handle to a tape: a mutable reference to its current 'Head'.
newtype Tape = Tape { getTape :: IORef Head }

-- | Used internally to push sensitivities down the chain.
--
-- @backPropagate k cells ss@ treats the front cell of @cells@ as the entry
-- with index @k@; each following cell has the next lower index. For every
-- cell it reads the adjoint stored at that cell's index. It then adds
-- @partial * adjoint@ into the slot of each input the cell records. The
-- result is the index reached when the cells run out.
backPropagate :: Int -> Cells -> STArray s Int Double -> ST s Int
backPropagate k cells ss = case cells of
    Nil -> return k
    Unary i g rest -> do
      da <- readArray ss k
      bump i g da
      (backPropagate $! k - 1) rest ss
    Binary i j g h rest -> do
      da <- readArray ss k
      -- Slot i is written before slot j is read. An entry whose two inputs
      -- share an index (e.g. @x * x@) therefore accumulates both contributions.
      bump i g da
      bump j h da
      (backPropagate $! k - 1) rest ss
  where
    -- Add @w * da@ to the adjoint at index @ix@. The sum is forced so the
    -- array holds a value, not a thunk.
    bump ix w da = do
      old <- readArray ss ix
      writeArray ss ix $! old + w * da

-- | Extract the partials from the current chain for a given AD variable.
--
-- Constants ('Zero' and 'Lift') have no tape entry and yield @[]@. For a
-- tape-backed value with index @k@, this:
--
-- 1. snapshots the tape and skips the cells newer than @k@;
-- 2. seeds the adjoint of @k@ with @1@ in a zero-filled array over @[0 .. k]@;
-- 3. runs 'backPropagate';
-- 4. returns the adjoints for indices @0@ up to the index at which
--    'backPropagate' ran out of cells.
partials :: forall s. Reifies s Tape => ReverseDouble s -> [Double]
partials Zero     = []
partials (Lift _) = []
partials (ReverseDouble k _) = [sensitivities ! ix | ix <- [0 .. stop]]
  where
    -- Snapshot of the tape: most recent index @n@ and newest-first cells @t@.
    Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
    -- Cells are stored newest first. Dropping @n - k@ of them leaves the
    -- cell that produced index @k@ at the front.
    fromK = dropCells (n - k) t
    (stop, sensitivities) = runST $ do
      adj <- newArray (0, k) 0
      writeArray adj k 1
      v <- backPropagate k fromK adj
      frozen <- Unsafe.unsafeFreeze adj
      return (v, frozen)

-- | Construct a tape that starts with @vs@ variables, and run @k@ with that
-- tape reflected into the type parameter @s@.
--
-- The fresh head has counter @vs@ and no cells. NOINLINE keeps the
-- 'unsafePerformIO' allocation from being duplicated or shared through
-- inlining.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  ref <- newIORef (Head vs Nil)
  return (reify (Tape ref) k)
{-# NOINLINE reifyTape #-}

-- | Construct a tape that starts with @vs@ variables, and run @k@ with that
-- tape reflected into a 'Typeable' type parameter @s@.
--
-- The fresh head has counter @vs@ and no cells. NOINLINE keeps the
-- 'unsafePerformIO' allocation from being duplicated or shared through
-- inlining.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  ref <- newIORef (Head vs Nil)
  return (reifyTypeable (Tape ref) k)
{-# NOINLINE reifyTypeableTape #-}

-- | Push a unary cell (input index @i@, partial @di@) onto a tape head.
--
-- Returns the new head and the new entry's index, which is the previous
-- counter plus one. Both are forced before the pair is built.
un :: Int -> Double -> Head -> (Head, Int)
un i di (Head r t) =
  let !r' = r + 1
      !h  = Head r' (Unary i di t)
  in (h, r')
{-# INLINE un #-}

-- | Push a binary cell (input indices @i@, @j@ with partials @di@, @dj@) onto
-- a tape head.
--
-- Returns the new head and the new entry's index, which is the previous
-- counter plus one. Both are forced before the pair is built.
bin :: Int -> Int -> Double -> Double -> Head -> (Head, Int)
bin i j di dj (Head r t) =
  let !r' = r + 1
      !h  = Head r' (Binary i j di dj t)
  in (h, r')
{-# INLINE bin #-}

-- | Atomically apply a head transformation to the tape reflected from @p@
-- and return its result.
--
-- Uses the strict 'atomicModifyIORef''. The new head and the result are
-- forced immediately, rather than leaving a lazy @fst (f old)@ thunk in the
-- 'IORef' that only collapses when some later result is demanded.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p = atomicModifyIORef' (getTape (reflect p))
{-# INLINE modifyTape #-}

-- | Create a new entry on the chain for a unary function.
--
-- Takes the function @f@, its derivative @di@ with respect to its input, the
-- variable id @i@ of that input, and the input's value @b@. The result value
-- @f b@ is forced before the cell is pushed. Used by 'unary' and 'binary'
-- internally.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
{-# INLINE unarily #-}

-- | Create a new entry on the chain for a binary function.
--
-- Takes the function @f@, its derivatives @di@ and @dj@ with respect to its
-- two inputs, and the inputs' variable ids (@i@, @j@) and values (@b@, @c@).
-- The result value @f b c@ is forced before the cell is pushed. Used by
-- 'binary' internally.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c = ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
{-# INLINE binarily #-}

#endif

-- | A 'Double' in reverse mode. The phantom parameter @s@ identifies the tape
-- (see 'reifyTape').
--
-- * 'Zero' — the constant 0; has no tape entry.
-- * 'Lift' — any other constant; has no tape entry.
-- * 'ReverseDouble' — a value tracked on the tape: its tape index and its
--   primal value.
data ReverseDouble s where
  Zero :: ReverseDouble s
  Lift :: {-# UNPACK #-} !Double -> ReverseDouble s
  ReverseDouble :: {-# UNPACK #-} !Int -> {-# UNPACK #-} !Double -> ReverseDouble s
  deriving (Show, Typeable)

-- | Constants are represented without a tape entry. Scaling by a scalar goes
-- through 'lift1', whose recorded partial is the scalar itself (its
-- reciprocal for '^/').
instance Reifies s Tape => Mode (ReverseDouble s) where
  type Scalar (ReverseDouble s) = Double

  -- Only constants can be known to be zero; tape-backed values never are.
  -- The literal pattern compares with (==), so negative zero also matches.
  isKnownZero Zero     = True
  isKnownZero (Lift 0) = True
  isKnownZero _        = False

  isKnownConstant ReverseDouble{} = False
  isKnownConstant _               = True

  auto = Lift
  zero = Zero

  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a
a

-- | Addition, recorded through 'binary' with partial derivative 1 for both
-- operands.
(<+>) :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
(<+>) = binary (+) 1 1

-- | The value carried by a 'ReverseDouble', ignoring any tape information.
-- 'Zero' has value @0@.
primal :: ReverseDouble s -> Double
primal Zero                = 0
primal (Lift a)            = a
primal (ReverseDouble _ a) = a

-- | Constant inputs produce constant outputs with no tape entry. Each
-- tape-backed input contributes a unary or binary entry carrying the supplied
-- partial derivatives.
instance Reifies s Tape => Jacobian (ReverseDouble s) where
  type D (ReverseDouble s) = Id Double

  unary f _         Zero                = Lift (f 0)
  unary f _         (Lift a)            = Lift (f a)
  unary f (Id dadi) (ReverseDouble i b) = unarily f dadi i b

  -- The derivative function is evaluated at the argument's primal value.
  lift1 f df b = unary f (df (Id pb)) b
    where
      pb = primal b

  -- Like 'lift1', but the result @a = f pb@ is computed once. @df@ receives
  -- it as its first argument, alongside the argument's primal value.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b
    where
      pb = primal b
      a  = f pb

  -- Both operands constant: the result is a constant.
  binary f _         _         Zero     Zero     = Lift (f 0 0)
  binary f _         _         Zero     (Lift c) = Lift (f 0 c)
  binary f _         _         (Lift b) Zero     = Lift (f b 0)
  binary f _         _         (Lift b) (Lift c) = Lift (f b c)
  -- Exactly one operand on the tape: record a unary entry for it, with the
  -- constant operand fixed in the value function.
  binary f _         (Id dadc) Zero                (ReverseDouble i c) = unarily (f 0) dadc i c
  binary f _         (Id dadc) (Lift b)            (ReverseDouble i c) = unarily (f b) dadc i c
  binary f (Id dadb) _         (ReverseDouble i b) Zero                = unarily (`f` 0) dadb i b
  binary f (Id dadb) _         (ReverseDouble i b) (Lift c)            = unarily (`f` c) dadb i b
  -- Both operands on the tape: record a binary entry.
  binary f (Id dadb) (Id dadc) (ReverseDouble i b) (ReverseDouble j c) = binarily f dadb dadc i b j c

  -- The partials are evaluated at the operands' primal values.
  lift2 f df b c = binary f dadb dadc b c
    where
      (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- Like 'lift2', but the result @a@ is computed once. @df@ receives it as
  -- its first argument, alongside both operands' primal values.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c
    where
      pb = primal b
      pc = primal c
      a  = f pb pc
      (dadb, dadc) = df (Id a) (Id pb) (Id pc)

-- | Multiplication via 'lift2' (product rule). The partial with respect to
-- each operand is the other operand's primal value.
mul :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
mul = lift2 (*) (\x y -> (y, x))

-- Macro parameters consumed by the included instances.h template.
-- BODY1 and BODY2 expand to the context `Reifies s Tape =>` and HEAD
-- expands to the type `(ReverseDouble s)`. Presumably the template uses
-- them to build the contexts and heads of the generated numeric instances,
-- and NO_Bounded presumably opts this type out of a Bounded instance.
-- TODO(review): confirm both against instances.h.
#define BODY1(x) Reifies s Tape =>
#define BODY2(x,y) Reifies s Tape =>
#define HEAD (ReverseDouble s)
#define NO_Bounded
#include "instances.h"

-- | Helper that extracts the derivative of a chain when the chain was
-- constructed with 1 variable.
--
-- The 'Proxy' argument is ignored at runtime; it only pins the tape type
-- parameter @s@. The result is the sum of every entry returned by
-- 'partials'. For a one-variable chain, that sum is the derivative with
-- respect to that single variable.
derivativeOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> Double
derivativeOf _ = sum . partials
{-# INLINE derivativeOf #-}

-- | Helper that extracts both the primal and derivative of a chain when the
-- chain was constructed with 1 variable.
--
-- The first component is the primal value of the chain. The second is the
-- summed partial computed by derivativeOf, using the same 'Proxy'.
derivativeOf' :: Reifies s Tape => Proxy s -> ReverseDouble s -> (Double, Double)
derivativeOf' p r = (primal r, derivativeOf p r)
{-# INLINE derivativeOf' #-}


-- | Return an 'Array' of 'partials' given bounds for the variable IDs.
--
-- The i-th element of 'partials' is stored at index i. 'accumArray' starts
-- every slot at 0, so a range wider than the partials list leaves the extra
-- slots at 0. Every generated index must lie inside the given bounds;
-- 'accumArray' does not accept out-of-range indices.
partialArrayOf :: Reifies s Tape => Proxy s -> (Int, Int) -> ReverseDouble s -> Array Int Double
partialArrayOf _ vbounds = accumArray (+) 0 vbounds . zip [0 ..] . partials
{-# INLINE partialArrayOf #-}

-- | Return an 'IntMap' of sparse partials, keyed by variable ID.
--
-- The i-th element of 'partials' is stored under key i. The keys produced
-- by @zip [0 ..]@ are strictly ascending and distinct, which is the
-- precondition of 'fromDistinctAscList'.
-- NOTE(review): no filtering happens here, so zero partials still get
-- entries; any sparsity comes from 'partials' itself.
partialMapOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> IntMap Double
partialMapOf _ = fromDistinctAscList . zip [0 ..] . partials
{-# INLINE partialMapOf #-}

-- | Build a tape variable from its primal value and its variable ID.
--
-- The value comes first and the ID second. This is the reverse of the field
-- order of the 'ReverseDouble' constructor, which takes the ID first.
var :: Double -> Int -> ReverseDouble s
var a v = ReverseDouble v a

-- | The variable ID stored in a value built with the 'ReverseDouble'
-- constructor (for example by 'var' or 'bind').
--
-- Any value not built with that constructor falls through to the second
-- clause and raises an error.
varId :: ReverseDouble s -> Int
varId (ReverseDouble v _) = v
varId _ = error "varId: not a Var"

-- | Turn every element of a container into a fresh tape variable.
--
-- Elements are visited in traversal order and receive IDs 0, 1, 2, ...
-- The second component is @(0, n)@, where @n@ is the number of elements.
--
-- NOTE(review): @(0, n)@ spans n + 1 slots, although only IDs 0 .. n - 1
-- are handed out. Callers presumably use it as array bounds, so it is kept
-- exactly as is.
bind :: Traversable f => f Double -> (f (ReverseDouble s), (Int, Int))
bind xs = (r, (0, hi))
  where
    (r, hi) = runState (mapM freshVar xs) 0
    -- Use the current counter as the new ID and advance the counter.
    -- The seq forces the incremented counter before the pair is built,
    -- presumably so the count does not grow into a chain of lazy additions.
    freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')

-- | Replace every variable in a container with the array entry at its
-- 'varId'. The array is presumably one built by 'partialArrayOf'; confirm
-- against callers.
--
-- Indexing with '!' fails if an ID lies outside the array bounds, and
-- 'varId' fails on values that are not variables.
unbind :: Functor f => f (ReverseDouble s) -> Array Int Double -> f Double
unbind xs ys = fmap (\v -> ys ! varId v) xs

-- | Like 'unbind', but combines values instead of replacing them.
--
-- For each variable, the supplied function receives the primal value of
-- that variable and the array entry at its 'varId'. The same bounds caveat
-- as '!' applies.
unbindWith :: Functor f => (Double -> b -> c) -> f (ReverseDouble s) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs

-- | Replace every variable in a container with the map entry at its
-- 'varId'. Variables with no entry in the map become 0.
unbindMap :: Functor f => f (ReverseDouble s) -> IntMap Double -> f Double
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs

-- | Map counterpart of 'unbindWith' with an explicit default.
--
-- For each variable, the supplied function receives the primal value of
-- that variable and the map entry at its 'varId'. The first argument is
-- used when the map has no entry for that ID.
unbindMapWithDefault :: Functor f => b -> (Double -> b -> c) -> f (ReverseDouble s) -> IntMap b -> f c
unbindMapWithDefault z f xs ys =
  fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs