{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------
module Numeric.AD.Jet
  ( Jet(..)
  , headJet
  , tailJet
  , jet
  , unjet
  ) where

import Data.Functor.Rep
import Data.Typeable
import Control.Comonad.Cofree

-- ':-' associates to the right at precedence 3, so a tower can be written
-- @a :- b :- c :- ...@ without parentheses.
infixr 3 :-

-- | A 'Jet' is a tower of all (higher order) partial derivatives of a function.
--
-- At each step, a @'Jet' f@ is wrapped in another layer worth of @f@.
--
-- > a :- f a :- f (f a) :- f (f (f a)) :- ...
--
-- The type is polymorphically recursive: the tail of a @'Jet' f a@ is a
-- @'Jet' f (f a)@. There is no terminating constructor, so every fully
-- defined 'Jet' is an infinite tower.
data Jet f a = a :- Jet f (f a)
  deriving Typeable

-- | An already-rendered value: a precedence-aware show function with the
-- type of the underlying value erased.
--
-- Used to sidestep the need for UndecidableInstances.
newtype Showable = Showable (Int -> ShowS)

-- | Showing a 'Showable' just runs the wrapped renderer at the requested
-- precedence.
instance Show Showable where
  showsPrec d (Showable f) = f d

-- | Capture the 'showsPrec' rendering of a value as a 'Showable', forgetting
-- the value's type.
showable :: Show a => a -> Showable
showable a = Showable (`showsPrec` a)

-- Polymorphic recursion precludes 'Data' in its current form, as no Data1
-- class exists.
-- Polymorphic recursion also breaks 'show' for 'Jet'!
-- TODO: factor Show1 out of Lifted?
--
-- The tail of a @Jet f a@ has type @Jet f (f a)@, so showing it directly
-- would demand @Show (f a)@, then @Show (f (f a))@, and so on. Instead the
-- innermost elements of the tail are converted to 'Showable' first, turning
-- it into a @Jet f (f Showable)@. Showing that only needs the constraints
-- already listed in the instance head, so the recursion closes.
instance (Functor f, Show (f Showable), Show a) => Show (Jet f a) where
  -- Precedences 4 (left operand) and 3 (right operand) mirror @infixr 3 :-@.
  showsPrec d (a :- as) = showParen (d > 3) $
    showsPrec 4 a . showString " :- " . showsPrec 3 (fmap showable <$> as)

-- | Apply a function to every element of the tower.
--
-- The head is mapped directly. The tail is a @Jet f (f a)@, so the function
-- is lifted through one more layer of @f@ before recursing.
instance Functor f => Functor (Jet f) where
  fmap f (a :- as) = f a :- fmap (fmap f) as

-- | Fold the head first, then each successive layer of the tower.
--
-- A 'Jet' has no terminating constructor, so this yields a result only for
-- monoids whose 'mappend' is lazy enough in its second argument.
instance Foldable f => Foldable (Jet f) where
  foldMap f (a :- as) = f a `mappend` foldMap (foldMap f) as

-- | Traverse the head, then each successive layer of the tower, rebuilding
-- the 'Jet' inside the applicative.
--
-- The tail is a @Jet f (f a)@, so the action is lifted through one more
-- layer of @f@ with the inner 'traverse' before recursing.
instance Traversable f => Traversable (Jet f) where
  traverse f (a :- as) = (:-) <$> f a <*> traverse (traverse f) as

-- | Take the tail of a 'Jet': everything after the first element, which is
-- itself a 'Jet' with one more layer of @f@ around its elements.
tailJet :: Jet f a -> Jet f (f a)
tailJet (_ :- as) = as
{-# INLINE tailJet #-}

-- | Take the head of a 'Jet': its first, unwrapped element.
headJet :: Jet f a -> a
headJet (a :- _) = a
{-# INLINE headJet #-}

-- | Construct a 'Jet' by unzipping the layers of a 'Cofree' 'Comonad'.
--
-- The root label becomes the head. Each child is converted recursively and
-- the resulting @f@ of jets is transposed into a single jet of @f@-layers.
jet :: Functor f => Cofree f a -> Jet f a
jet (a :< as) = a :- dist (jet <$> as)
  where
    -- Transpose a container of jets into a jet of containers: gather all the
    -- heads into one layer, then recurse on all the tails.
    -- The signature is required because 'dist' is polymorphically recursive.
    dist :: Functor f => f (Jet f a) -> Jet f (f a)
    dist x = (headJet <$> x) :- dist (tailJet <$> x)

-- | Rebuild a 'Cofree' 'Comonad' from a 'Jet', going in the opposite
-- direction to 'jet'.
--
-- The head becomes the root label. The tail is split into one 'Jet' per
-- position of @f@, and each of those is converted recursively.
unjet :: Representable f => Jet f a -> Cofree f a
unjet (a :- as) = a :< (unjet <$> undist as)
  where
    -- Transpose a jet of containers into a container of jets. For each
    -- position @i@ of @f@, the jet at that position takes the @i@-th element
    -- of the first layer as its head and the @i@-th jet of the transposed
    -- remainder as its tail.
    -- The signature is required because 'undist' is polymorphically recursive.
    undist :: Representable f => Jet f (f a) -> f (Jet f a)
    undist (fa :- fas) = tabulate $ \i -> index fa i :- index (undist fas) i