{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Combinators used internally by @Numeric.AD@
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Combinators
  ( zipWithT
  , zipWithDefaultT
  , withPrimal
  , fromBy
  , takeWhileDifferent
  ) where

import Data.Traversable (mapAccumL)
import Data.Foldable (toList)
import Numeric.AD.Mode
import Numeric.AD.Jacobian

-- | Zip a @'Foldable' f@ with a @'Traversable' g@, assuming @f@ has at least as
-- many entries as @g@.
--
-- The shape of the result comes from the @g@ argument. Entries of @f@ are
-- consumed left to right, one per element of @g@, and any surplus entries of
-- @f@ are ignored.
--
-- If @f@ runs out before @g@ does, the unmatched positions hold an 'error'
-- thunk. The failure therefore only surfaces when such an element is demanded.
--
-- >>> zipWithT (+) [1, 2, 3] [10, 20]
-- [11,22]
zipWithT :: (Foldable f, Traversable g) => (a -> b -> c) -> f a -> g b -> g c
zipWithT f as = snd . mapAccumL step (toList as)
  where
    -- The not-yet-used entries of @f@ are threaded through the traversal as
    -- the accumulator; the final leftover list is discarded by 'snd'.
    step (a:as') b = (as', f a b)
    step []      _ = error "zipWithT: second argument contains less entries than third argument"

-- | Zip a @'Foldable' f@ with a @'Traversable' g@, using the default value @z@
-- in place of entries of @f@ once @f@ is exhausted.
--
-- The shape of the result comes from the @g@ argument. Every element of @g@ is
-- paired with either the next entry of @f@ or, after @f@ runs out, with @z@.
-- Because of this it never fails, unlike 'zipWithT'. Surplus entries of @f@
-- are ignored.
--
-- >>> zipWithDefaultT 0 (+) [1, 2] [10, 20, 30]
-- [11,22,30]
zipWithDefaultT :: (Foldable f, Traversable g) => a -> (a -> b -> c) -> f a -> g b -> g c
zipWithDefaultT z f as = snd . mapAccumL step (toList as)
  where
    -- The remaining entries of @f@ are the accumulator. Once they are used up,
    -- every further element is paired with the default @z@. This matches
    -- zipping against @toList as ++ repeat z@ without building that list.
    step (a:as') b = (as', f a b)
    step []      b = ([], f z b)

-- | Used internally to define various 'Enum' combinators.
--
-- Applies 'unary' to @t@. The primal map is the constant function @const a@,
-- which ignores @t@'s primal, and @1@ is passed as the 'D' argument.
--
-- NOTE(review): presumably the result has primal @a@ while carrying @t@'s
-- derivative unchanged (a derivative of 1 with respect to @t@). Confirm
-- against the documentation of 'unary' in 'Jacobian'.
withPrimal :: Jacobian t => t -> Scalar t -> t
withPrimal t a = unary (const a) 1 t
{-# INLINE withPrimal #-}

-- | Used internally to define various 'Enum' combinators.
--
-- Combines @a@ and @delta@ with 'binary'. The primal function ignores both of
-- its arguments and returns @x@. The two 'D' arguments are @1@ (for @a@) and
-- @'fromIntegral' n@ (for @delta@).
--
-- NOTE(review): presumably the result stands for the @n@-th element of an
-- enumeration starting at @a@ with stride @delta@. Its primal would be @x@ and
-- its derivative that of @a + n * delta@. Confirm against the documentation
-- of 'binary' in 'Jacobian'.
fromBy :: Jacobian t => t -> t -> Int -> Scalar t -> t
fromBy a delta n x = binary (\_ _ -> x) 1 (fromIntegral n) a delta

-- | Used internally to implement functions which truncate a list once the
-- stream of results converges.
--
-- Returns the input up to and including the first element that is equal to
-- its immediate successor. That successor and everything after it are dropped.
-- A list with no two equal adjacent elements is returned unchanged.
--
-- The output is produced lazily, so this may be applied to an infinite stream.
-- On such a stream it terminates as soon as two adjacent elements coincide.
--
-- >>> takeWhileDifferent [1, 2, 3, 3, 4]
-- [1,2,3]
takeWhileDifferent :: Eq a => [a] -> [a]
takeWhileDifferent (x1 : rest@(x2 : _))
  | x1 == x2  = [x1]
  | otherwise = x1 : takeWhileDifferent rest
-- Zero or one element: nothing left to compare against.
takeWhileDifferent xs = xs