{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Mode
  (
  -- * AD modes
    Mode(..)
  , pattern KnownZero
  , pattern Auto
  ) where

import Numeric.Natural
import Data.Complex
import Data.Int
import Data.Ratio
import Data.Word

-- Fixities for the scalar operators declared in 'Mode'. All three sit at
-- precedence 7, the level the Prelude assigns to '*' and '/'.
-- NOTE(review): '^/' is right-associative here although the Prelude's '/'
-- is left-associative -- confirm this is intended.
infixr 7 *^
infixl 7 ^*
infixr 7 ^/

-- | An AD mode: a numeric type @t@ together with a numeric type of
-- constants, @'Scalar' t@, that can be embedded into @t@ with 'auto'.
--
-- Every method has a default. The default for 'auto' is only available when
-- @'Scalar' t ~ t@, so an instance whose scalar type is the type itself may
-- be declared with an empty body, while any other instance must define
-- 'auto' explicitly.
class (Num t, Num (Scalar t)) => Mode t where
  -- | The type of constants that can be embedded into @t@.
  -- Defaults to @t@ itself.
  type Scalar t
  type Scalar t = t

  -- | Conservative test for whether a value is known to be a constant.
  --
  -- Allowed to return 'False' for items with a zero derivative, at the cost
  -- of giving more NaNs than strictly necessary. The default always answers
  -- 'False'.
  isKnownConstant :: t -> Bool
  isKnownConstant _ = False

  -- | Recover the underlying scalar when the value is known to be a
  -- constant. The default always answers 'Nothing'.
  asKnownConstant :: t -> Maybe (Scalar t)
  asKnownConstant _ = Nothing

  -- | Conservative test for whether a value is known to be zero.
  --
  -- Allowed to return 'False' for zero, at the cost of giving more NaNs than
  -- strictly necessary. The default always answers 'False'.
  isKnownZero :: t -> Bool
  isKnownZero _ = False

  -- | Embed a constant.
  --
  -- When @'Scalar' t ~ t@ the default is the identity.
  auto :: Scalar t -> t
  default auto :: (Scalar t ~ t) => Scalar t -> t
  auto = id

  -- | Scalar-vector multiplication.
  --
  -- Defaults to embedding the scalar with 'auto' and multiplying on the left.
  (*^) :: Scalar t -> t -> t
  a *^ b = auto a * b

  -- | Vector-scalar multiplication.
  --
  -- Defaults to embedding the scalar with 'auto' and multiplying on the right.
  (^*) :: t -> Scalar t -> t
  a ^* b = a * auto b

  -- | Scalar division.
  --
  -- Defaults to multiplying by the reciprocal of the scalar via '^*'.
  (^/) :: Fractional (Scalar t) => t -> Scalar t -> t
  a ^/ b = a ^* recip b

  -- |
  -- @'zero' = 'auto' 0@
  zero :: t
  zero = auto 0

-- | Bidirectional pattern synonym over 'isKnownZero' and 'zero'.
--
-- As a pattern it matches any value for which 'isKnownZero' answers 'True';
-- as an expression it builds 'zero'.
pattern KnownZero :: Mode s => s
pattern KnownZero <- (isKnownZero -> True) where
  KnownZero = zero

-- | Bidirectional pattern synonym over 'asKnownConstant' and 'auto'.
--
-- As a pattern it matches any value for which 'asKnownConstant' answers
-- 'Just', binding the recovered scalar; as an expression it embeds the
-- scalar with 'auto'.
pattern Auto :: Mode s => Scalar s -> s
pattern Auto n <- (asKnownConstant -> Just n) where
  Auto n = auto n

-- | 'Double' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Double where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Divide directly instead of using the class default @a ^* recip b@.
  (^/) = (/)

-- | 'Float' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Float where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Divide directly instead of using the class default @a ^* recip b@.
  (^/) = (/)

-- | 'Int' is its own 'Scalar' (the associated-type default), so every value
-- is already a known constant.
instance Mode Int where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Int@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Integer' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Integer where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Integer@ instance, so this cannot be
  -- called unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Int8' is its own 'Scalar' (the associated-type default), so every value
-- is already a known constant.
instance Mode Int8 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Int8@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Int16' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Int16 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Int16@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Int32' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Int32 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Int32@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Int64' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Int64 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Int64@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Natural' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Natural where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Natural@ instance, so this cannot be
  -- called unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Word' is its own 'Scalar' (the associated-type default), so every value
-- is already a known constant.
instance Mode Word where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Word@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Word8' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Word8 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Word8@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Word16' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Word16 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Word16@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Word32' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Word32 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Word32@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | 'Word64' is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Mode Word64 where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Well-typed only through the method's @Fractional (Scalar t)@ context;
  -- base provides no @Fractional Word64@ instance, so this cannot be called
  -- unless one is supplied elsewhere.
  (^/) = (/)

-- | @'Complex' a@ is its own 'Scalar' (the associated-type default), so
-- every value is already a known constant.
instance RealFloat a => Mode (Complex a) where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Divide directly instead of using the class default @a ^* recip b@.
  (^/) = (/)

-- | @'Ratio' a@ is its own 'Scalar' (the associated-type default), so every
-- value is already a known constant.
instance Integral a => Mode (Ratio a) where
  -- Every value is a constant.
  isKnownConstant _ = True
  -- The value is its own scalar.
  asKnownConstant = Just
  -- Compare against the literal zero.
  isKnownZero x = 0 == x
  -- Divide directly instead of using the class default @a ^* recip b@.
  (^/) = (/)