{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------
module Numeric.AD.Jet
  ( Jet(..)
  , headJet
  , tailJet
  , jet
  , unjet
  ) where

import Data.Functor.Rep
import Data.Typeable
import Control.Comonad.Cofree

-- The constructor associates to the right, so @x :- y :- rest@ parses as
-- @x :- (y :- rest)@.
infixr 3 :-

-- | A 'Jet' is a tower of all (higher order) partial derivatives of a
-- function.
--
-- Every layer is wrapped in one more @f@ than the layer before it:
--
-- > a :- f a :- f (f a) :- f (f (f a)) :- ...
--
-- There is no terminating constructor, so a fully-defined 'Jet' never ends.
data Jet f a
  = a :- Jet f (f a)
  deriving (Typeable)

-- | A value captured only through its precedence-aware renderer.
--
-- Wrapping elements as 'Showable' lets the 'Show' instance for 'Jet' ask
-- for @Show (f Showable)@ at a single fixed type, which sidesteps the need
-- for UndecidableInstances.
newtype Showable = Showable (Int -> ShowS)

-- | Showing a 'Showable' runs the captured renderer at the precedence it
-- is asked for.
instance Show Showable where
  showsPrec d (Showable render) = render d

-- | Capture a value's 'Show' instance. The choice of precedence is
-- deferred until the resulting 'Showable' is itself shown.
showable :: Show a => a -> Showable
showable a = Showable (`showsPrec` a)

-- NOTE: Polymorphic recursion (the tail of a @'Jet' f a@ is a
-- @'Jet' f (f a)@) precludes a 'Data' instance in its current form, as no
-- @Data1@ class exists.
--
-- Polymorphic recursion also breaks a naive 'Show' for 'Jet': its context
-- would have to mention @Show a@, @Show (f a)@, @Show (f (f a))@, ... at
-- every depth. The instance below avoids this by re-wrapping each layer's
-- elements as 'Showable', so the only requirement on @f@ is
-- @Show (f Showable)@.
--
-- TODO: factor Show1 out of Lifted?

-- | Renders as a right-nested chain of ':-' at precedence 3, matching the
-- @infixr 3@ fixity: the head is shown at precedence 4 and the tail at 3.
-- A fully-defined 'Jet' never ends, so 'show' yields an infinite string
-- (produced lazily; consume a prefix with e.g. 'take').
instance (Functor f, Show (f Showable), Show a) => Show (Jet f a) where
  showsPrec d (a :- as) = showParen (d > 3) $
    showsPrec 4 a . showString " :- " . showsPrec 3 (fmap showable <$> as)

-- | Maps over every layer. At each step deeper into the tower the function
-- is pushed under one more 'fmap' for @f@.
instance Functor f => Functor (Jet f) where
  fmap f (a :- as) = f a :- fmap (fmap f) as

-- | Folds the head first, then every deeper layer in turn.
--
-- A fully-defined 'Jet' never ends, so a fold can only terminate when the
-- combining 'Monoid' is lazy in its second argument.
instance Foldable f => Foldable (Jet f) where
  foldMap f (a :- as) = f a `mappend` foldMap (foldMap f) as

-- | Traverses the head first, then every deeper layer in turn. The effect
-- for each deeper layer is obtained by traversing under one more @f@.
instance Traversable f => Traversable (Jet f) where
  traverse f (a :- as) = (:-) <$> f a <*> traverse (traverse f) as

-- | Take the tail of a 'Jet': drop the head and return the rest of the
-- tower, whose elements carry one more layer of @f@.
tailJet :: Jet f a -> Jet f (f a)
tailJet (_ :- as) = as
{-# INLINE tailJet #-}

-- | Take the head of a 'Jet': the outermost element of the tower.
headJet :: Jet f a -> a
headJet (a :- _) = a
{-# INLINE headJet #-}

-- | Construct a 'Jet' by unzipping the layers of a 'Cofree' 'Comonad'.
--
-- The root label becomes the head. Every child subtree is converted to a
-- 'Jet', and the resulting @f ('Jet' f a)@ is transposed into the tail,
-- a @'Jet' f (f a)@.
jet :: Functor f => Cofree f a -> Jet f a
jet (a :< as) = a :- dist (jet <$> as) where
  -- Transpose an @f@ of jets into a jet of @f@s, one layer at a time: the
  -- heads form the first layer and the tails are transposed recursively.
  -- The signature is required because the recursive call is at the
  -- different type @f ('Jet' f (f a))@ (polymorphic recursion), which GHC
  -- cannot infer.
  dist :: Functor f => f (Jet f a) -> Jet f (f a)
  dist x = (headJet <$> x) :- dist (tailJet <$> x)

-- | Rebuild a 'Cofree' from a 'Jet', converting in the opposite direction
-- to 'jet'.
--
-- The head becomes the root label. The tail is transposed back into an
-- @f@ of jets, and each of those is converted recursively.
unjet :: Representable f => Jet f a -> Cofree f a
unjet (a :- as) = a :< (unjet <$> undist as) where
  -- Transpose a jet of @f@s into an @f@ of jets. 'Representable' is needed
  -- (not just 'Functor') so that position @i@ can be projected out of every
  -- layer with 'index' and the results reassembled with 'tabulate'. The
  -- signature is required because of polymorphic recursion.
  undist :: Representable f => Jet f (f a) -> f (Jet f a)
  undist (fa :- fas) = tabulate $ \i -> index fa i :- index rest i
    where
      -- Bound outside the lambda so the transposed remainder is built once
      -- and shared by every position @i@, instead of relying on the
      -- optimiser's full-laziness pass to float it out.
      rest = undist fas