{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Sparse.Double
  ( Monomial(..)
  , emptyMonomial
  , addToMonomial
  , indices
  , SparseDouble(..)
  , apply
  , vars
  , d, d', ds
  , skeleton
  , spartial
  , partial
  , vgrad
  , vgrad'
  , vgrads
  , Grad(..)
  , Grads(..)
  , terms
  , primal
  ) where

import Prelude hiding (lookup)
import Control.Comonad.Cofree
import Control.Monad (join, guard)
import Data.Data
import Data.IntMap (IntMap, unionWith, findWithDefault, singleton, lookup)
import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | We only store partials in sorted order, so the map contained in a partial
-- will only contain partials with equal or greater keys to that of the map in
-- which it was found. This should be key for efficiently computing sparse Hessians.
-- There are only @n + k - 1@ choose @n@ distinct @n@th-order partial derivatives
-- of a function with @k@ inputs.
data SparseDouble
  = Sparse {-# UNPACK #-} !Double (IntMap SparseDouble)
    -- ^ A primal value paired with a map from variable index to the
    -- (recursively sparse) partial derivative with respect to that variable.
    -- The map is deliberately lazy: several combinators in this module tie
    -- a knot through it.
  | Zero
    -- ^ A value known to be zero, with no derivative information stored.
  deriving (Show, Data, Typeable)

-- | Turn a container of inputs into a container of independent variables.
--
-- Elements are numbered from @0@ in traversal order; the element at
-- position @n@ with value @a@ becomes @'Sparse' a ('singleton' n ('auto' 1))@,
-- i.e. its only non-zero first partial is the one with respect to itself.
vars :: Traversable f => f Double -> f SparseDouble
vars = snd . mapAccumL var 0 where
  var :: Int -> Double -> (Int, SparseDouble)
  -- The bang keeps the running index evaluated while traversing.
  var !n a = (n + 1, Sparse a $ singleton n $ auto 1)
{-# INLINE vars #-}

-- | Run a function over 'SparseDouble's on plain 'Double' inputs by first
-- promoting every input to an independent variable with 'vars'.
apply :: Traversable f => (f SparseDouble -> b) -> f Double -> b
apply f = f . vars
{-# INLINE apply #-}

-- | Read the first-order partial derivatives out of a 'SparseDouble'.
--
-- The first argument only supplies the shape of the result; its elements
-- are ignored. The slot at position @n@ (counting from @0@ in traversal
-- order) receives the primal of entry @n@ of the derivative map, or @0@
-- when that entry is absent. A 'Zero' input yields all zeros.
d :: Traversable f => f b -> SparseDouble -> f Double
d fs Zero = 0 <$ fs
d fs (Sparse _ da) = snd $ mapAccumL (\ !n _ -> (n + 1, maybe 0 primal $ lookup n da)) 0 fs
{-# INLINE d #-}

-- | Like 'd', but also returns the primal value as the first component.
--
-- The elements of the first argument are ignored; only its shape is used.
-- A 'Zero' input yields a primal of @0@ and all-zero partials.
d' :: Traversable f => f Double -> SparseDouble -> (Double, f Double)
d' fs Zero = (0, 0 <$ fs)
d' fs (Sparse a da) = (a, snd $ mapAccumL (\ !n _ -> (n + 1, maybe 0 primal $ lookup n da)) 0 fs)
{-# INLINE d' #-}

-- | Unfold a 'SparseDouble' into the (lazy, infinite) tower of all of its
-- partial derivatives.
--
-- The head of the result is the primal. Each child at a given position is
-- the tower obtained by differentiating once more with respect to the index
-- found at that position of @'skeleton' fs@; its head is looked up with
-- 'partial' using the indices of the monomial accumulated so far.
-- A 'Zero' input produces a tower consisting solely of zeros.
--
-- The elements of the first argument are ignored; only its shape is used.
ds :: Traversable f => f b -> SparseDouble -> Cofree f Double
ds fs Zero = r where r = 0 :< (r <$ fs)
ds fs as@(Sparse a _) = a :< (go emptyMonomial <$> fns) where
  -- NOTE(review): skeleton comes from Sparse.Common; presumably it labels
  -- each position with its variable index -- confirm there.
  fns = skeleton fs
  -- go :: Monomial -> Int -> Cofree f Double
  go ix i = partial (indices ix') as :< (go ix' <$> fns) where
    -- Monomial describing the derivative reached after also
    -- differentiating with respect to variable i.
    ix' = addToMonomial i ix
{-# INLINE ds #-}

-- | Follow a path of variable indices through the nested derivative maps and
-- return the 'SparseDouble' found at the end.
--
-- An empty path returns the argument unchanged. A missing index, or reaching
-- 'Zero' with indices still to consume, yields 'Zero'.
partialS :: [Int] -> SparseDouble -> SparseDouble
partialS []     a             = a
partialS (n:ns) (Sparse _ da) = partialS ns $ findWithDefault Zero n da
partialS _      Zero          = Zero
{-# INLINE partialS #-}

-- | Follow a path of variable indices through the nested derivative maps and
-- return the primal of the value found at the end.
--
-- A missing index is treated as the constant @0@, and 'Zero' yields @0@
-- regardless of the remaining path, so the result is always defined.
partial :: [Int] -> SparseDouble -> Double
partial []     (Sparse a _)  = a
partial (n:ns) (Sparse _ da) = partial ns $ findWithDefault (auto 0) n da
partial _      Zero          = 0
{-# INLINE partial #-}

-- | A variant of 'partial' that reports absence instead of defaulting to zero.
--
-- Returns 'Just' the primal reached by following the path, or 'Nothing' as
-- soon as an index is missing from a derivative map or a 'Zero' is
-- encountered (including @spartial [] Zero@).
spartial :: [Int] -> SparseDouble -> Maybe Double
spartial [] (Sparse a _) = Just a
spartial (n:ns) (Sparse _ da) = do
  a' <- lookup n da
  spartial ns a'
spartial _  Zero         = Nothing
{-# INLINE spartial #-}

-- | The value itself, discarding all derivative information.
-- 'Zero' has primal @0@.
primal :: SparseDouble -> Double
primal (Sparse a _) = a
primal Zero = 0

-- | 'SparseDouble' as an AD mode over 'Double' scalars.
instance Mode SparseDouble where
  type Scalar SparseDouble = Double

  -- A constant: the given primal with an empty derivative map.
  auto a = Sparse a IntMap.empty

  zero = Zero

  -- Zero is either the dedicated constructor or a literal 0 primal that
  -- carries no partials.
  isKnownZero Zero = True
  isKnownZero (Sparse 0 m) = null m
  isKnownZero _ = False

  -- A value is a known constant exactly when it stores no partials.
  isKnownConstant Zero = True
  isKnownConstant (Sparse _ m) = null m

  -- Just the primal when no partials are stored, Nothing otherwise.
  asKnownConstant Zero = Just 0
  asKnownConstant (Sparse a m) = a <$ guard (null m)

  -- Scaling by a scalar on the right / left: applied to the primal and,
  -- recursively, to every stored partial. Zero is absorbing.
  Zero        ^* _ = Zero
  Sparse a as ^* b = Sparse (a * b) $ fmap (^* b) as
  _ *^ Zero        = Zero
  a *^ Sparse b bs = Sparse (a * b) $ fmap (a *^) bs

  -- Division by a scalar, distributed the same way as (^*).
  Zero        ^/ _ = Zero
  Sparse a as ^/ b = Sparse (a / b) $ fmap (^/ b) as

infixr 6 <+>

-- | Addition of 'SparseDouble's.
--
-- 'Zero' is the identity on either side. Otherwise the primals are added and
-- the derivative maps are merged, recursively adding the partials that are
-- present in both operands.
(<+>) :: SparseDouble -> SparseDouble -> SparseDouble
Zero <+> a = a
a <+> Zero = a
Sparse a as <+> Sparse b bs = Sparse (a + b) $ unionWith (<+>) as bs

-- The instances for Jacobian for Sparse and Tower are almost identical;
-- could easily be made exactly equal by small changes.
-- | Chain-rule combinators for 'SparseDouble'.
--
-- In every method the first argument computes the primal of the result and
-- the remaining function/derivative arguments supply the local derivative(s)
-- by which the stored partials of the input(s) are multiplied. A 'Zero'
-- input is treated as the primal @0@ with no partials to propagate.
instance Jacobian SparseDouble where
  type D SparseDouble = SparseDouble

  -- Unary function with an already-computed derivative dadb.
  unary f _ Zero = auto (f 0)
  unary f dadb (Sparse pb bs) = Sparse (f pb) $ IntMap.map (* dadb) bs

  -- Unary function whose derivative is computed from the input.
  lift1 f _ Zero = auto (f 0)
  lift1 f df b@(Sparse pb bs) = Sparse (f pb) $ IntMap.map (* df b) bs

  -- Unary function whose derivative may also depend on the result. The
  -- result a is defined in terms of itself, so this relies on the derivative
  -- map being built lazily.
  lift1_ f _  Zero = auto (f 0)
  lift1_ f df b@(Sparse pb bs) = a where
    a = Sparse (f pb) $ IntMap.map (df a b *) bs

  -- Binary function with already-computed derivatives dadb and dadc. When
  -- both inputs carry partials the two scaled maps are merged with (<+>).
  binary f _    _    Zero           Zero           = auto (f 0 0)
  binary f _    dadc Zero           (Sparse pc dc) = Sparse (f 0  pc) $ IntMap.map (dadc *) dc
  binary f dadb _    (Sparse pb db) Zero           = Sparse (f pb 0 ) $ IntMap.map (dadb *) db
  binary f dadb dadc (Sparse pb db) (Sparse pc dc) = Sparse (f pb pc) $
    unionWith (<+>)  (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

  -- Binary function whose derivatives are computed from the inputs; a
  -- missing input is passed to df as zero.
  lift2 f _  Zero             Zero = auto (f 0 0)
  lift2 f df Zero c@(Sparse pc dc) = Sparse (f 0 pc) $ IntMap.map (dadc *) dc where dadc = snd (df zero c)
  lift2 f df b@(Sparse pb db) Zero = Sparse (f pb 0) $ IntMap.map (* dadb) db where dadb = fst (df b zero)
  lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) da where
    (dadb, dadc) = df b c
    da = unionWith (<+>) (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

  -- Binary function whose derivatives may also depend on the result a
  -- (again tied lazily through the derivative map).
  lift2_ f _  Zero             Zero = auto (f 0 0)
  lift2_ f df b@(Sparse pb db) Zero = a where a = Sparse (f pb 0) (IntMap.map (fst (df a b zero) *) db)
  lift2_ f df Zero c@(Sparse pc dc) = a where a = Sparse (f 0 pc) (IntMap.map (* snd (df a zero c)) dc)
  lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = a where
    (dadb, dadc) = df a b c
    a = Sparse (f pb pc) da
    da = unionWith (<+>) (IntMap.map (dadb *) db) (IntMap.map (dadc *) dc)

-- Splice in the shared instance template from instances.h.
-- NOTE(review): the template is not visible from this file. Presumably the
-- first macro names the instance head, the two body macros (empty here)
-- supply an optional constraint context, and the last one suppresses a
-- Bounded instance; confirm against instances.h.
#define HEAD SparseDouble
#define BODY1(x)
#define BODY2(x,y)
#define NO_Bounded
#include "instances.h"

-- | Plumbing for the variadic gradient functions @vgrad@ and @vgrad'@.
--
-- Judging by the instances in this module, @i@ is a curried function
-- @SparseDouble -> ... -> SparseDouble@, @o@ is the matching
-- @Double -> ... -> [Double]@ and @o'@ the matching
-- @Double -> ... -> (Double, [Double])@. The functional dependencies make
-- each of the three types determine the other two.
class Grad i o o' | i -> o o', o -> i o', o' -> i o where
  -- | Apply @i@ to arguments taken from a list.
  pack :: i -> [SparseDouble] -> SparseDouble
  -- | Build the curried result type from a function on the argument list.
  unpack :: ([Double] -> [Double]) -> o
  -- | As 'unpack', for a function that also returns a single 'Double'.
  unpack' :: ([Double] -> (Double, [Double])) -> o'

-- | Base case: no arguments are left to consume.
instance Grad SparseDouble [Double] (Double, [Double]) where
  -- The function is fully applied; any remaining list is ignored.
  pack i _ = i
  -- Close off the accumulated argument list.
  unpack f = f []
  unpack' f = f []

-- | Inductive case: peel off one argument.
instance Grad i o o' => Grad (SparseDouble -> i) (Double -> o) (Double -> o') where
  -- Feed the next list element to the function and continue with the rest.
  pack f (a:as) = pack (f a) as
  -- Running out of elements means the list was shorter than the arity.
  pack _ [] = error "Grad.pack: logic error"
  -- Accept one more Double; composing with (a:) on the right keeps the
  -- collected arguments in the order in which they were supplied.
  unpack f a = unpack (f . (a:))
  unpack' f a = unpack' (f . (a:))

-- | Variadic gradient: turn a curried function over 'SparseDouble's into a
-- curried function over 'Double's that returns the list of first partials.
--
-- The 'Double' arguments are collected into a list by 'unpack', the original
-- function is applied to the corresponding variables via 'pack' and 'apply',
-- and the partials are read back with 'd'.
vgrad :: Grad i o o' => i -> o
vgrad i = unpack (unsafeGrad (pack i)) where
  unsafeGrad f as = d as $ apply f as
{-# INLINE vgrad #-}

-- | Variadic gradient that also returns the value of the function.
--
-- Works like 'vgrad', but reads the result back with @d'@, so the final
-- result is the primal paired with the list of first partials.
vgrad' :: Grad i o o' => i -> o'
vgrad' i = unpack' (unsafeGrad' (pack i)) where
  unsafeGrad' f as = d' as $ apply f as
{-# INLINE vgrad' #-}

-- | Plumbing for the variadic higher-order derivative function @vgrads@.
--
-- Judging by the instances in this module, @i@ is a curried function
-- @SparseDouble -> ... -> SparseDouble@ and @o@ is the matching
-- @Double -> ... -> Cofree [] Double@; each type determines the other.
class Grads i o | i -> o, o -> i where
  -- | Apply @i@ to arguments taken from a list.
  packs :: i -> [SparseDouble] -> SparseDouble
  -- | Build the curried result type from a function on the argument list.
  unpacks :: ([Double] -> Cofree [] Double) -> o

-- | Base case: no arguments are left to consume.
instance Grads SparseDouble (Cofree [] Double) where
  -- The function is fully applied; any remaining list is ignored.
  packs i _ = i
  -- Close off the accumulated argument list.
  unpacks f = f []

-- | Inductive case: peel off one argument.
instance Grads i o => Grads (SparseDouble -> i) (Double -> o) where
  -- Feed the next list element to the function and continue with the rest.
  packs f (a:as) = packs (f a) as
  -- Running out of elements means the list was shorter than the arity.
  -- (The message previously named Grad.pack, copied from the Grad instance.)
  packs _ [] = error "Grads.packs: logic error"
  -- Accept one more Double; composing with (a:) on the right keeps the
  -- collected arguments in the order in which they were supplied.
  unpacks f a = unpacks (f . (a:))

-- | Variadic higher-order derivatives: turn a curried function over
-- 'SparseDouble's into a curried function over 'Double's that returns the
-- full derivative tower produced by 'ds'.
vgrads :: Grads i o => i -> o
vgrads i = unpacks (unsafeGrads (packs i)) where
  unsafeGrads f as = ds as $ apply f as
{-# INLINE vgrads #-}

-- | Structural test for the 'Zero' constructor.
--
-- This does not inspect primals: @Sparse 0 m@ is reported as non-zero.
isZero :: SparseDouble -> Bool
isZero Zero = True
isZero _ = False

-- | Product of two 'SparseDouble's, including all higher-order partials.
--
-- 'Zero' is absorbing. Otherwise the primal is the product of the primals,
-- and the derivative map is generated lazily, one variable index at a time,
-- from the partials of both operands.
mul :: SparseDouble -> SparseDouble -> SparseDouble
mul Zero _ = Zero
mul _ Zero = Zero
mul f@(Sparse _ am) g@(Sparse _ bm) = Sparse (primal f * primal g) (derivs 0 emptyMonomial) where
  -- Partials obtained by differentiating the monomial mi once more with
  -- respect to some variable w >= v. Starting at v rather than 0 keeps the
  -- keys along any path non-decreasing.
  derivs v mi = IntMap.unions (map fn [v..kMax]) where
    fn w
      -- Every term has a structurally zero factor: omit the entry.
      | and zs = IntMap.empty
      | otherwise = IntMap.singleton w (Sparse (sum vals) (derivs w mi'))
      where
        mi' = addToMonomial w mi
        (zs,vals) = unzip (map derVal (terms mi'))
        -- NOTE(review): terms comes from Sparse.Common; presumably each
        -- triple is a coefficient plus the split of mi' between the two
        -- factors (Leibniz rule) -- confirm there.
        derVal (bin,mif,mig) = (isZero fder || isZero gder, fromIntegral bin * primal fder * primal gder) where
          fder = partialS (indices mif) f
          gder = partialS (indices mig) g
  -- Largest variable index at the top level of either operand; -1 when both
  -- derivative maps are empty, which makes the range [v..kMax] empty.
  kMax = maybe (-1) (fst.fst) (IntMap.maxViewWithKey am) `max` maybe (-1) (fst.fst) (IntMap.maxViewWithKey bm)