{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Combinators used internally by @Numeric.AD@
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Combinators
  ( zipWithT
  , zipWithDefaultT
  , withPrimal
  , fromBy
  , takeWhileDifferent
  ) where

import Data.Traversable (mapAccumL)
import Data.Foldable (toList)
import Numeric.AD.Mode
import Numeric.AD.Jacobian

-- | Zip a @'Foldable' f@ with a @'Traversable' g@, combining elements pairwise
-- in left-to-right order. The result keeps the shape of @g@.
--
-- Precondition: @f@ has at least as many entries as @g@. Surplus entries of
-- @f@ are ignored. If @f@ runs out first, the unmatched positions are
-- 'error'.
zipWithT :: (Foldable f, Traversable g) => (a -> b -> c) -> f a -> g b -> g c
zipWithT combine xs ys = snd (mapAccumL step (toList xs) ys)
  where
    -- Thread the not-yet-used elements of @f@ through the traversal of @g@,
    -- popping one element per position.
    step (x:rest) y = (rest, combine x y)
    step []       _ = error "zipWithT: second argument contains less entries than third argument"

-- | Zip a @'Foldable' f@ with a @'Traversable' g@, combining elements pairwise
-- in left-to-right order. The default value @z@ stands in for @f@'s element at
-- every position of @g@ left over once @f@ is exhausted.
--
-- Unlike 'zipWithT' this never calls 'error' itself. When @f@ has fewer
-- entries than @g@, the remaining positions become @combine z y@. Surplus
-- entries of @f@ are ignored.
zipWithDefaultT :: (Foldable f, Traversable g) => a -> (a -> b -> c) -> f a -> g b -> g c
zipWithDefaultT z combine xs = snd . mapAccumL step (toList xs)
  where
    -- Pop the next element of @f@. Once none remain, keep supplying @z@.
    -- This replaces padding with @repeat z@, so no partial fallback is needed.
    step (x:rest) y = (rest, combine x y)
    step []       y = ([], combine z y)

-- | Used internally to define various 'Enum' combinators.
--
-- Applies 'unary' to @t@ with two arguments: a scalar function that ignores
-- its input and returns @a@, and @1@ as the 'D' argument. Presumably this
-- yields @t@ with its primal replaced by @a@ and its derivative information
-- carried over unscaled. NOTE(review): confirm against 'unary'.
withPrimal :: Jacobian t => t -> Scalar t -> t
withPrimal t a = unary replacePrimal 1 t
  where
    replacePrimal _ = a
{-# INLINE withPrimal #-}

-- | Used internally to define various 'Enum' combinators.
--
-- Combines @a@ and @delta@ with 'binary'. It passes a scalar function that
-- ignores both inputs and returns @x@, plus the 'D' arguments @1@ (for @a@)
-- and @fromIntegral n@ (for @delta@). Presumably the result has primal @x@ and
-- the derivative of @a + n * delta@, e.g. the @n@-th element of an
-- @enumFromThen@-style sequence. NOTE(review): confirm against 'binary'.
fromBy :: Jacobian t => t -> t -> Int -> Scalar t -> t
fromBy a delta n x = binary constantly 1 scale a delta
  where
    constantly _ _ = x
    scale = fromIntegral n

-- | Used internally to implement functions which truncate a stream of
-- results once it converges.
--
-- Cuts the list at the first pair of equal adjacent elements, keeping one
-- copy of the repeated element. Lists with fewer than two elements, and lists
-- with no equal adjacent pair, come back unchanged. The list is consumed
-- lazily.
takeWhileDifferent :: Eq a => [a] -> [a]
takeWhileDifferent (x1 : rest@(x2 : _))
  | x1 == x2  = [x1]
  | otherwise = x1 : takeWhileDifferent rest
takeWhileDifferent shortList = shortList