{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Sparse.Double
  ( Monomial(..)
  , emptyMonomial
  , addToMonomial
  , indices
  , SparseDouble(..)
  , apply
  , vars
  , d, d', ds
  , skeleton
  , spartial
  , partial
  , vgrad
  , vgrad'
  , vgrads
  , Grad(..)
  , Grads(..)
  , terms
  , primal
  ) where

import Prelude hiding (lookup)
import Control.Comonad.Cofree
import Control.Monad (join, guard)
import Data.Data
import Data.IntMap (IntMap, unionWith, findWithDefault, singleton, lookup)
import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | We only store partials in sorted order, so the map contained in a partial
-- will only contain partials with equal or greater keys to that of the map in
-- which it was found. This should be key for efficiently computing sparse
-- Hessians: there are only @n + k - 1@ choose @n@ distinct nth partial
-- derivatives of a function with k inputs.
data SparseDouble
  = Sparse {-# UNPACK #-} !Double (IntMap SparseDouble)
    -- ^ A primal value together with a map of partial derivatives. Each
    -- partial is itself a 'SparseDouble', so higher-order partials nest.
  | Zero
    -- ^ A known zero: its primal is @0@ and it carries no partials.
  deriving (Show, Data, Typeable)

-- | Turn a container of input values into a container of variables.
--
-- Elements are numbered @0, 1, 2, ...@ in traversal order. The element
-- numbered @n@ becomes a 'Sparse' value whose primal is the input and whose
-- only partial is @auto 1@, stored under key @n@.
vars :: Traversable f => f Double -> f SparseDouble
vars = snd . mapAccumL var 0 where
  var :: Int -> Double -> (Int, SparseDouble)
  -- The counter is forced on every step so it does not build up thunks.
  var !n a = (n + 1, Sparse a $ singleton n $ auto 1)
{-# INLINE vars #-}

-- | Run a function over the variables produced by 'vars' from the given
-- container of inputs.
apply :: Traversable f => (f SparseDouble -> b) -> f Double -> b
apply f = f . vars
{-# INLINE apply #-}

-- | Read the first-order partials out of a result, shaped like @fs@.
--
-- Only the shape of @fs@ is used; its elements are ignored. Positions are
-- numbered from @0@ in traversal order, and position @n@ receives the primal
-- of the partial stored under key @n@, or @0@ when no such key is present.
-- A 'Zero' result gives @0@ at every position.
d :: Traversable f => f b -> SparseDouble -> f Double
d fs Zero = 0 <$ fs
d fs (Sparse _ da) = snd $ mapAccumL step 0 fs where
  -- 'lookup' here is the 'IntMap' lookup; the Prelude one is hidden.
  step !n _ = (n + 1, maybe 0 primal $ lookup n da)
{-# INLINE d #-}

-- | Like 'd', but also returns the primal value of the result.
--
-- The first component is the primal (@0@ for 'Zero'). The second holds, for
-- each position @n@ of @fs@ in traversal order, the primal of the partial
-- stored under key @n@, or @0@ when that key is absent. The elements of
-- @fs@ themselves are ignored.
d' :: Traversable f => f Double -> SparseDouble -> (Double, f Double)
d' fs Zero = (0, 0 <$ fs)
d' fs (Sparse a da) = (a, snd $ mapAccumL step 0 fs) where
  -- 'lookup' here is the 'IntMap' lookup; the Prelude one is hidden.
  step !n _ = (n + 1, maybe 0 primal $ lookup n da)
{-# INLINE d' #-}

-- | Unfold a result into a lazy, infinite 'Cofree' tree of derivatives
-- shaped like @fs@.
--
-- The root holds the primal. Each child is reached by extending the
-- 'Monomial' accumulated so far with that child's index, as supplied by
-- @skeleton fs@, and holds the 'partial' of the original value along the
-- resulting index list. A 'Zero' result unfolds to a tree of zeros.
ds :: Traversable f => f b -> SparseDouble -> Cofree f Double
ds fs Zero = r where
  -- Self-referential knot: every level of the tree is this same node.
  r = 0 :< (r <$ fs)
ds fs as@(Sparse a _) = a :< (go emptyMonomial <$> fns) where
  fns = skeleton fs
  -- go :: Monomial -> Int -> Cofree f Double
  go ix i = partial (indices ix') as :< (go ix' <$> fns) where
    ix' = addToMonomial i ix
{-# INLINE ds #-}

-- | Follow a list of keys through the nested partial maps and return the
-- 'SparseDouble' found at the end.
--
-- An empty key list returns the value unchanged. A key that is missing at
-- any level, or a 'Zero' met along the way, gives 'Zero'.
partialS :: [Int] -> SparseDouble -> SparseDouble
partialS []     a             = a
partialS (n:ns) (Sparse _ da) = partialS ns $ findWithDefault Zero n da
partialS _      Zero          = Zero
{-# INLINE partialS #-}

-- | Follow a list of keys through the nested partial maps and return the
-- primal value found at the end.
--
-- A key that is missing at any level, or a 'Zero' met along the way, gives
-- @0@. A missing key defaults to 'Zero' (as in 'partialS'), so the last
-- equation ends the walk immediately with @0@.
partial :: [Int] -> SparseDouble -> Double
partial []     (Sparse a _)  = a
partial (n:ns) (Sparse _ da) = partial ns $ findWithDefault Zero n da
partial _      Zero          = 0
{-# INLINE partial #-}

-- | Variant of 'partial' that reports absence instead of defaulting to @0@.
--
-- Returns 'Nothing' when a key is missing at any level or when a 'Zero' is
-- met (including a 'Zero' at the top). Otherwise returns 'Just' the primal
-- reached after consuming every key.
spartial :: [Int] -> SparseDouble -> Maybe Double
spartial [] (Sparse a _) = Just a
spartial (n:ns) (Sparse _ da) = do
  -- 'lookup' here is the 'IntMap' lookup; the Prelude one is hidden.
  a' <- lookup n da
  spartial ns a'
spartial _  Zero         = Nothing
{-# INLINE spartial #-}

-- | The primal value: the stored 'Double' of a 'Sparse', or @0@ for 'Zero'.
primal :: SparseDouble -> Double
primal (Sparse a _) = a
primal Zero = 0

-- | 'Mode' instance with 'Double' scalars. 'Zero' is absorbing for every
-- scaling operator, so scaling never allocates for a known zero.
instance Mode SparseDouble where
  type Scalar SparseDouble = Double

  -- A lifted constant has a primal value and no partials.
  auto a = Sparse a IntMap.empty

  zero = Zero

  -- A 'Sparse' counts as a known zero only when its primal matches the
  -- literal 0 and it carries no partials.
  isKnownZero Zero = True
  isKnownZero (Sparse 0 m) = null m
  isKnownZero _ = False

  -- Constant means "no partials recorded", whatever the primal is.
  isKnownConstant Zero = True
  isKnownConstant (Sparse _ m) = null m

  -- Yields the primal exactly when 'isKnownConstant' would hold.
  asKnownConstant Zero = Just 0
  asKnownConstant (Sparse a m) = a <$ guard (null m)

  -- Scaling on the right: the primal and, recursively, every partial.
  Zero        ^* _ = Zero
  Sparse a as ^* b = Sparse (a * b) $ fmap (^* b) as

  -- Scaling on the left: the primal and, recursively, every partial.
  _ *^ Zero        = Zero
  a *^ Sparse b bs = Sparse (a * b) $ fmap (a *^) bs

  -- Division by a scalar: the primal and, recursively, every partial.
  Zero        ^/ _ = Zero
  Sparse a as ^/ b = Sparse (a / b) $ fmap (^/ b) as

infixr 6 <+>

-- | Sum of two sparse values.
--
-- 'Zero' is the identity on either side. For two 'Sparse' values the primals
-- are added and the partial maps are merged: a key present on both sides is
-- summed recursively, and a key present on one side only is kept as is.
(<+>) :: SparseDouble -> SparseDouble -> SparseDouble
Zero <+> a = a
a <+> Zero = a
Sparse a as <+> Sparse b bs = Sparse (a + b) $ unionWith (<+>) as bs

-- The instances for Jacobian for Sparse and Tower are almost identical;
-- could easily be made exactly equal by small changes.
instance Jacobian SparseDouble where
  type D SparseDouble = SparseDouble
  -- unary f dadb b: apply f to the primal of b and multiply every stored
  -- partial of b by the supplied derivative dadb. A 'Zero' argument has no
  -- partials, so the result is the constant f 0.
  unary f _ Zero = auto (f 0)
  unary f dadb (Sparse pb bs) = Sparse (f pb) $ IntMap.map (* dadb) bs

  -- lift1 f df b: like unary, except that the derivative is computed from
  -- the argument as df b. A 'Zero' argument gives the constant f 0 without
  -- consulting df.
  lift1 f _ Zero = auto (f 0)
  lift1 f df b@(Sparse pb bs) = Sparse (f pb) $ IntMap.map (* df b) bs

  -- lift1_ f df b: like lift1, except that df also receives the result.
  -- The result a is defined in terms of itself; this relies on laziness, as
  -- the partials of a refer back to a through df a b.
  lift1_ f _  Zero = auto (f 0)
  lift1_ f df b@(Sparse pb bs) = a where
    a = Sparse (f pb) $ IntMap.map (df a b *) bs

  -- binary f dadb dadc b c: apply f to the two primals and combine the
  -- partials of b scaled by dadb with the partials of c scaled by dadc.
  -- A 'Zero' argument contributes a primal of 0 and no partials, so its
  -- derivative is never used.
  binary f _    _    Zero           Zero           = auto (f 0 0)
  binary f _    dadc Zero           (Sparse pc dc) = Sparse (f 0  pc) $ IntMap.map (dadc *) dc
  binary f dadb _    (Sparse pb db) Zero           = Sparse (f pb 0 ) $ IntMap.map (dadb *) db
  binary f dadb dadc (Sparse pb db) (Sparse pc dc) = Sparse (f pb pc) $
    -- Keys present on both sides have their scaled partials added.
    unionWith (<+>) (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

  -- lift2 f df b c: like binary, except that both derivatives come from
  -- df b c as a pair (dadb, dadc). When one argument is 'Zero', df is called
  -- with zero in that position and only the other component is used.
  lift2 f _  Zero             Zero = auto (f 0 0)
  lift2 f df Zero c@(Sparse pc dc) = Sparse (f 0 pc) $ IntMap.map (dadc *) dc where
    dadc = snd (df zero c)
  lift2 f df b@(Sparse pb db) Zero = Sparse (f pb 0) $ IntMap.map (* dadb) db where
    dadb = fst (df b zero)
  lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) da where
    (dadb, dadc) = df b c
    -- Keys present on both sides have their scaled partials added.
    da = unionWith (<+>) (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

  -- Lift a binary function @f@ whose partial derivatives are computed by
  -- @df@ from the /result/ as well as from both arguments.
  --
  -- In every non-constant case the result @a@ is defined in terms of itself:
  -- it is handed to @df@ while it is still being built, so @df@ must not
  -- force the derivative map of its first argument before producing output.
  --
  -- A 'Zero' operand contributes the primal value 0 and no derivatives, so
  -- only the partial derivative belonging to the other operand is used.
  lift2_ f _ Zero Zero = auto (f 0 0)
  lift2_ f df b@(Sparse pb db) Zero = a where
    a = Sparse (f pb 0) (IntMap.map (dadb *) db)
    -- Named so the factor is computed once and shared by every map entry.
    dadb = fst (df a b zero)
  lift2_ f df Zero c@(Sparse pc dc) = a where
    a = Sparse (f 0 pc) (IntMap.map (* dadc) dc)
    -- Named so the factor is computed once and shared by every map entry.
    dadc = snd (df a zero c)
  lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = a where
    (dadb, dadc) = df a b c
    a = Sparse (f pb pc) da
    -- Scale each operand's derivatives by its partial, then merge the two
    -- maps, combining entries that share a variable index with '<+>'.
    da = unionWith (<+>) (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

#define HEAD SparseDouble
#define BODY1(x)
#define BODY2(x,y)
#define NO_Bounded
#include "instances.h"

-- | Variadic plumbing behind 'vgrad' and 'vgrad''.
--
-- @i@ is the function being differentiated, @o@ is the shape of the
-- gradient-only result and @o'@ the shape of the (value, gradient) result.
-- The functional dependencies make any one of the three determine the other
-- two.
class Grad i o o' | i -> o o', o -> i o', o' -> i o where
  -- | Feed the elements of the list to @i@ one argument at a time until a
  -- 'SparseDouble' result is reached.
  pack :: i -> [SparseDouble] -> SparseDouble
  -- | Turn a list-consuming gradient function into one that receives its
  -- 'Double' inputs as individual arguments.
  unpack :: ([Double] -> [Double]) -> o
  -- | Like 'unpack', for a function that returns a 'Double' paired with a
  -- list of 'Double's.
  unpack' :: ([Double] -> (Double, [Double])) -> o'

-- | Base case: the function has received all of its arguments.
instance Grad SparseDouble [Double] (Double, [Double]) where
  -- The answer is already a 'SparseDouble'; the remaining list is ignored.
  pack i _ = i
  -- No further arguments to collect: run the function on the empty list.
  unpack f = f []
  unpack' f = f []

-- | Inductive case: peel one argument off the function being differentiated.
instance Grad i o o' => Grad (SparseDouble -> i) (Double -> o) (Double -> o') where
  -- Apply the function to the next input and continue with the rest.
  pack f (a:as) = pack (f a) as
  -- The list ran out before the function was saturated.
  pack _ [] = error "Grad.pack: logic error"
  -- Accept one more 'Double' and prepend it to the list that the
  -- list-consuming function will eventually receive.
  unpack f a = unpack (f . (a :))
  unpack' f a = unpack' (f . (a :))

-- | Variadic gradient: given a function of @n@ 'SparseDouble' arguments,
-- produce a function of @n@ 'Double' arguments that returns a @['Double']@.
--
-- The inputs are gathered into a list by 'unpack', the function is run on
-- them through 'apply' (with 'pack' feeding it its arguments), and 'd' reads
-- one 'Double' per input back out of the resulting 'SparseDouble'.
vgrad :: Grad i o o' => i -> o
vgrad i = unpack (unsafeGrad (pack i)) where
  unsafeGrad f as = d as $ apply f as
{-# INLINE vgrad #-}

-- | Variant of 'vgrad' whose result is a @('Double', ['Double'])@ pair, as
-- produced by 'd''.
--
-- The inputs are gathered into a list by 'unpack'', the function is run on
-- them through 'apply' (with 'pack' feeding it its arguments), and 'd''
-- extracts the pair from the resulting 'SparseDouble'.
vgrad' :: Grad i o o' => i -> o'
vgrad' i = unpack' (unsafeGrad' (pack i)) where
  unsafeGrad' f as = d' as $ apply f as
{-# INLINE vgrad' #-}

-- | Variadic plumbing behind 'vgrads'.
--
-- @i@ is the function being differentiated and @o@ the shape of the result;
-- the functional dependencies make each determine the other.
class Grads i o | i -> o, o -> i where
  -- | Feed the elements of the list to @i@ one argument at a time until a
  -- 'SparseDouble' result is reached.
  packs :: i -> [SparseDouble] -> SparseDouble
  -- | Turn a list-consuming function that yields a @'Cofree' [] 'Double'@
  -- into one that receives its 'Double' inputs as individual arguments.
  unpacks :: ([Double] -> Cofree [] Double) -> o

-- | Base case: the function has received all of its arguments.
instance Grads SparseDouble (Cofree [] Double) where
  -- The answer is already a 'SparseDouble'; the remaining list is ignored.
  packs i _ = i
  -- No further arguments to collect: run the function on the empty list.
  unpacks f = f []

-- | Inductive case: peel one argument off the function being differentiated.
instance Grads i o => Grads (SparseDouble -> i) (Double -> o) where
  -- Apply the function to the next input and continue with the rest.
  packs f (a:as) = packs (f a) as
  -- The list ran out before the function was saturated. The message names
  -- this class and method so the failure is not mistaken for one in 'Grad'.
  packs _ [] = error "Grads.packs: logic error"
  -- Accept one more 'Double' and prepend it to the list that the
  -- list-consuming function will eventually receive.
  unpacks f a = unpacks (f . (a :))

-- | Variadic counterpart of 'vgrad' that returns a @'Cofree' [] 'Double'@,
-- as produced by 'ds', instead of a flat list.
--
-- The inputs are gathered into a list by 'unpacks', the function is run on
-- them through 'apply' (with 'packs' feeding it its arguments), and 'ds'
-- builds the 'Cofree' structure from the resulting 'SparseDouble'.
vgrads :: Grads i o => i -> o
vgrads i = unpacks (unsafeGrads (packs i)) where
  unsafeGrads f as = ds as $ apply f as
{-# INLINE vgrads #-}

-- | 'True' exactly when the argument is the 'Zero' constructor.
--
-- This is a structural test: a 'Sparse' value is never reported as zero,
-- whatever its primal part is.
isZero :: SparseDouble -> Bool
isZero Zero = True
isZero _ = False

-- | Product of two sparse values.
--
-- If either factor is 'Zero' the product is 'Zero'. Otherwise the primal part
-- is the product of the two primals, and the derivative map is built by
-- 'derivs'.
--
-- NOTE(review): the per-entry sum over 'terms' looks like the generalized
-- Leibniz product rule, with @bin@ as the multiplicity of each split of the
-- monomial between the two factors; confirm against the definition of 'terms'.
mul :: SparseDouble -> SparseDouble -> SparseDouble
mul Zero _ = Zero
mul _ Zero = Zero
mul f@(Sparse _ am) g@(Sparse _ bm) = Sparse (primal f * primal g) (derivs 0 emptyMonomial) where
  -- Derivatives reached by extending the monomial @mi@ with one more variable
  -- index @w@. Only indices from @v@ up to 'kMax' are tried, so nested maps
  -- never hold a key smaller than the key they were found under.
  derivs v mi = IntMap.unions (map fn [v .. kMax]) where
    fn w
      -- Every term has a structurally zero factor: emit no entry and do not
      -- recurse any deeper along this index.
      | and skips = IntMap.empty
      | otherwise = IntMap.singleton w (Sparse (sum vals) (derivs w mi'))
      where
        mi' = addToMonomial w mi
        (skips, vals) = unzip (map derVal (terms mi'))
        -- For one term: whether either factor's derivative is 'Zero', and
        -- the term's numeric contribution to this entry.
        derVal (bin, mif, mig) =
          (isZero fder || isZero gder, fromIntegral bin * primal fder * primal gder)
          where
            fder = partialS (indices mif) f
            gder = partialS (indices mig) g
  -- Largest variable index present in either top-level derivative map.
  kMax = maxKey am `max` maxKey bm
  -- Largest key of a map, or -1 when the map is empty.
  maxKey m = maybe (-1) (fst . fst) (IntMap.maxViewWithKey m)