{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Numeric.Backprop.Tuple (
T0(..)
, T2(..)
, t2Tup, tupT2
, uncurryT2, curryT2
, t2_1, t2_2
, T3(..)
, t3Tup, tupT3
, t3_1, t3_2, t3_3
, uncurryT3, curryT3
, T(..)
, indexT
, tOnly, onlyT, tSplit, tAppend, tProd, prodT
, tIx, tHead, tTail, tTake, tDrop
, constT, mapT, zipT
) where
import Control.DeepSeq
import Data.Bifunctor
import Data.Data
import Data.Kind
import Data.Type.Combinator
import Data.Type.Index
import Data.Type.Length
import Data.Type.Product
import GHC.Generics (Generic)
import Lens.Micro
import Lens.Micro.Internal hiding (Index)
import Type.Class.Known
import Type.Family.List
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
-- | Unit type with a single nullary constructor; it carries no data.
data T0 = T0
    deriving (Eq, Ord, Show, Read, Data, Generic)
-- | Forcing a 'T0' to normal form only requires evaluating its constructor.
instance NFData T0 where
    rnf T0 = ()
-- | Pair whose two fields are both strict.
--
-- The derived 'Functor' instance maps over the second component.
data T2 a b = T2 !a !b
    deriving (Eq, Ord, Show, Read, Functor, Data, Generic)
-- | Triple whose three fields are all strict.
--
-- The derived 'Functor' instance maps over the third component.
data T3 a b c = T3 !a !b !c
    deriving (Eq, Ord, Show, Read, Functor, Data, Generic)
-- | Heterogeneous tuple indexed by the type-level list of its element types.
--
-- 'TNil' is the empty tuple; ':&' prepends one element.  Both the element and
-- the tail are strict fields.
data T (as :: [Type]) where
    TNil :: T '[]
    (:&) :: !a -> !(T as) -> T (a ': as)
-- | Fully evaluates the first component, then the second.
instance (NFData a, NFData b) => NFData (T2 a b) where
    rnf (T2 a b) = rnf a `seq` rnf b
-- | Fully evaluates the three components in order.
instance (NFData a, NFData b, NFData c) => NFData (T3 a b c) where
    rnf (T3 a b c) = rnf a `seq` rnf b `seq` rnf c
-- | Fully evaluates every element of the tuple.
--
-- The strict fields of ':&' only bring each element to weak head normal
-- form, so each element must additionally be forced with its own 'rnf' for
-- the whole tuple to reach normal form.
instance ListC (NFData <$> as) => NFData (T as) where
    rnf = \case
      TNil    -> ()
      x :& xs -> rnf x `seq` rnf xs
-- | Maps the first function over the first component and the second function
-- over the second component.
instance Bifunctor T2 where
    bimap l r (T2 a b) = T2 (l a) (r b)
-- | Maps over the last two components; the first component is left untouched.
instance Bifunctor (T3 a) where
    bimap l r (T3 a b c) = T3 a (l b) (r c)
-- | Convert a 'T2' into an ordinary pair.
t2Tup :: T2 a b -> (a, b)
t2Tup = \case
    T2 a b -> (a, b)
-- | Convert an ordinary pair into a 'T2'.
tupT2 :: (a, b) -> T2 a b
tupT2 = \(a, b) -> T2 a b
-- | Convert a 'T3' into an ordinary triple.
t3Tup :: T3 a b c -> (a, b, c)
t3Tup = \case
    T3 a b c -> (a, b, c)
-- | Convert an ordinary triple into a 'T3'.
tupT3 :: (a, b, c) -> T3 a b c
tupT3 = \(a, b, c) -> T3 a b c
-- | Wrap a single value in a one-element 'T'.
onlyT :: a -> T '[a]
onlyT v = v :& TNil
-- | Extract the sole element of a one-element 'T'.
tOnly :: T '[a] -> a
tOnly t = case t of
    v :& _ -> v
-- | Turn a two-argument function into one that takes a 'T2'.
uncurryT2 :: (a -> b -> c) -> T2 a b -> c
uncurryT2 k = \case
    T2 a b -> k a b
-- | Turn a function on 'T2' into a two-argument function.
curryT2 :: (T2 a b -> c) -> a -> b -> c
curryT2 k a = k . T2 a
-- | Turn a three-argument function into one that takes a 'T3'.
uncurryT3 :: (a -> b -> c -> d) -> T3 a b c -> d
uncurryT3 k = \case
    T3 a b c -> k a b c
-- | Turn a function on 'T3' into a three-argument function.
curryT3 :: (T3 a b c -> d) -> a -> b -> c -> d
curryT3 k a b = k . T3 a b
-- | '_1' focuses on the first component of a 'T2' (same lens as 't2_1').
instance Field1 (T2 a b) (T2 a' b) a a' where
    _1 h = t2_1 h

-- | '_2' focuses on the second component of a 'T2' (same lens as 't2_2').
instance Field2 (T2 a b) (T2 a b') b b' where
    _2 h = t2_2 h
-- | '_1' focuses on the first component of a 'T3' (same lens as 't3_1').
instance Field1 (T3 a b c) (T3 a' b c) a a' where
    _1 h = t3_1 h

-- | '_2' focuses on the second component of a 'T3' (same lens as 't3_2').
instance Field2 (T3 a b c) (T3 a b' c) b b' where
    _2 h = t3_2 h

-- | '_3' focuses on the third component of a 'T3' (same lens as 't3_3').
instance Field3 (T3 a b c) (T3 a b c') c c' where
    _3 h = t3_3 h
-- | '_1' focuses on the first element of a non-empty 'T'.  The instance is
-- type-preserving: the element type cannot be changed through it.
instance Field1 (T (a ': as)) (T (a ': as)) a a where
    _1 = tHead

-- | '_2' focuses on the second element of a 'T' with at least two elements
-- (type-preserving).
instance Field2 (T (a ': b ': as)) (T (a ': b ': as)) b b where
    _2 = tTail . tHead

-- | '_3' focuses on the third element of a 'T' with at least three elements
-- (type-preserving).
instance Field3 (T (a ': b ': c ': as)) (T (a ': b ': c ': as)) c c where
    _3 = tTail . tTail . tHead
-- | Type-changing lens onto the first component of a 'T2'.
t2_1 :: Lens (T2 a b) (T2 a' b) a a'
t2_1 h (T2 a b) = fmap (\a' -> T2 a' b) (h a)
-- | Type-changing lens onto the second component of a 'T2'.
t2_2 :: Lens (T2 a b) (T2 a b') b b'
t2_2 h (T2 a b) = fmap (\b' -> T2 a b') (h b)
-- | Type-changing lens onto the first component of a 'T3'.
t3_1 :: Lens (T3 a b c) (T3 a' b c) a a'
t3_1 h (T3 a b c) = fmap rebuild (h a)
  where
    rebuild a' = T3 a' b c
-- | Type-changing lens onto the second component of a 'T3'.
t3_2 :: Lens (T3 a b c) (T3 a b' c) b b'
t3_2 h (T3 a b c) = fmap rebuild (h b)
  where
    rebuild b' = T3 a b' c
-- | Type-changing lens onto the third component of a 'T3'.
t3_3 :: Lens (T3 a b c) (T3 a b c') c c'
t3_3 h (T3 a b c) = fmap (\c' -> T3 a b c') (h c)
-- | Read the element of a 'T' selected by the given 'Index', by viewing
-- through 'tIx'.
indexT :: Index as a -> T as -> a
indexT i t = t ^. tIx i
-- | Lens onto the element of a 'T' selected by the given 'Index'.
--
-- 'IZ' focuses on the head; 'IS' skips the head and recurses into the tail.
tIx :: Index as a -> Lens' (T as) a
tIx IZ     h (y :& ys) = fmap (\y' -> y' :& ys) (h y)
tIx (IS n) h (y :& ys) = fmap (y :&) (tIx n h ys)
-- | Type-changing lens onto the first element of a non-empty 'T'.
tHead :: Lens (T (a ': as)) (T (b ': as)) a b
tHead h (y :& ys) = fmap (\y' -> y' :& ys) (h y)
-- | Type-changing lens onto everything after the first element of a
-- non-empty 'T'.
tTail :: Lens (T (a ': as)) (T (a ': bs)) (T as) (T bs)
tTail h (y :& ys) = fmap (\ys' -> y :& ys') (h ys)
-- | Concatenate two tuples; the elements of the first precede those of the
-- second.
tAppend :: T as -> T bs -> T (as ++ bs)
tAppend front back = case front of
    TNil    -> back
    y :& ys -> y :& tAppend ys back
infixr 5 `tAppend`
-- | Split a tuple in two.  The 'Length' witness determines how many leading
-- elements go into the first result; the remainder forms the second.
tSplit :: Length as -> T (as ++ bs) -> (T as, T bs)
tSplit LZ     ys        = (TNil, ys)
tSplit (LS n) (y :& ys) = first (y :&) (tSplit n ys)
-- | Type-changing lens onto the leading elements of a tuple, as delimited by
-- the 'Length' witness.  The remaining elements are re-appended unchanged
-- after the modified prefix.
tTake :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (cs ++ bs)) (T as) (T cs)
tTake l h t = case tSplit l t of
    (front, back) -> fmap (\front' -> tAppend @cs @bs front' back) (h front)
-- | Type-changing lens onto the elements that follow the prefix delimited by
-- the 'Length' witness.  The prefix is kept unchanged and the modified
-- suffix is appended to it.
tDrop :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (as ++ cs)) (T bs) (T cs)
tDrop l h t = case tSplit l t of
    (front, back) -> fmap (\back' -> tAppend front back') (h back)
-- | Convert a 'T' into a 'Tuple' with the same element types, element by
-- element.
tProd :: T as -> Tuple as
tProd = \case
    TNil    -> Ø
    y :& ys -> y ::< tProd ys
-- | Convert a 'Tuple' into a 'T' with the same element types, unwrapping the
-- 'I' around each element.
prodT :: Tuple as -> T as
prodT = \case
    Ø         -> TNil
    I y :< ys -> y :& prodT ys
-- | Trivial numeric instance: every operation ignores its arguments and
-- returns 'T0'.
instance Num T0 where
    (+) _ _     = T0
    (-) _ _     = T0
    (*) _ _     = T0
    negate      = const T0
    abs         = const T0
    signum      = const T0
    fromInteger = const T0
-- | Trivial instance: every operation ignores its arguments and returns 'T0'.
instance Fractional T0 where
    (/) _ _      = T0
    recip        = const T0
    fromRational = const T0
-- | Trivial instance: every operation ignores its arguments and returns 'T0'.
instance Floating T0 where
    pi          = T0
    (**) _ _    = T0
    logBase _ _ = T0
    exp         = const T0
    log         = const T0
    sqrt        = const T0
    sin         = const T0
    cos         = const T0
    asin        = const T0
    acos        = const T0
    atan        = const T0
    sinh        = const T0
    cosh        = const T0
    asinh       = const T0
    acosh       = const T0
    atanh       = const T0
-- | Combining two 'T0's ignores both arguments and returns 'T0'.
instance Semigroup T0 where
    (<>) _ _ = T0
-- | The identity element is 'T0'; 'mappend' is the 'Semigroup' operation.
instance Monoid T0 where
    mappend = (<>)
    mempty  = T0
-- | Componentwise arithmetic.  'fromInteger' converts the literal separately
-- for each component.
instance (Num a, Num b) => Num (T2 a b) where
    T2 a1 b1 + T2 a2 b2 = T2 (a1 + a2) (b1 + b2)
    T2 a1 b1 - T2 a2 b2 = T2 (a1 - a2) (b1 - b2)
    T2 a1 b1 * T2 a2 b2 = T2 (a1 * a2) (b1 * b2)
    negate              = bimap negate negate
    abs                 = bimap abs abs
    signum              = bimap signum signum
    fromInteger n       = T2 (fromInteger n) (fromInteger n)
-- | Componentwise division and reciprocal.  'fromRational' converts the
-- literal separately for each component.
instance (Fractional a, Fractional b) => Fractional (T2 a b) where
    T2 a1 b1 / T2 a2 b2 = T2 (a1 / a2) (b1 / b2)
    recip               = bimap recip recip
    fromRational r      = T2 (fromRational r) (fromRational r)
-- | Componentwise transcendental functions.
--
-- 'tan' and 'tanh' are defined componentwise rather than left to the class
-- defaults (@sin x / cos x@ and @sinh x / cosh x@ on the whole tuple): the
-- default 'tanh' divides two overflowing values for large-magnitude inputs
-- and so produces NaN where the component's own 'tanh' does not.
instance (Floating a, Floating b) => Floating (T2 a b) where
    pi                            = T2 pi pi
    T2 x1 y1 ** T2 x2 y2          = T2 (x1 ** x2) (y1 ** y2)
    logBase (T2 x1 y1) (T2 x2 y2) = T2 (logBase x1 x2) (logBase y1 y2)
    exp   (T2 x y)                = T2 (exp   x) (exp   y)
    log   (T2 x y)                = T2 (log   x) (log   y)
    sqrt  (T2 x y)                = T2 (sqrt  x) (sqrt  y)
    sin   (T2 x y)                = T2 (sin   x) (sin   y)
    cos   (T2 x y)                = T2 (cos   x) (cos   y)
    tan   (T2 x y)                = T2 (tan   x) (tan   y)
    asin  (T2 x y)                = T2 (asin  x) (asin  y)
    acos  (T2 x y)                = T2 (acos  x) (acos  y)
    atan  (T2 x y)                = T2 (atan  x) (atan  y)
    sinh  (T2 x y)                = T2 (sinh  x) (sinh  y)
    cosh  (T2 x y)                = T2 (cosh  x) (cosh  y)
    tanh  (T2 x y)                = T2 (tanh  x) (tanh  y)
    asinh (T2 x y)                = T2 (asinh x) (asinh y)
    acosh (T2 x y)                = T2 (acosh x) (acosh y)
    atanh (T2 x y)                = T2 (atanh x) (atanh y)
-- | Combines two pairs component by component.
instance (Semigroup a, Semigroup b) => Semigroup (T2 a b) where
    T2 a1 b1 <> T2 a2 b2 = T2 (a1 <> a2) (b1 <> b2)
-- | The identity element holds 'mempty' in both components; 'mappend' is the
-- 'Semigroup' operation.  On @base@ older than 4.11 the 'Semigroup'
-- constraints have to be listed explicitly.
#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b) => Monoid (T2 a b) where
#else
instance (Semigroup a, Semigroup b, Monoid a, Monoid b) => Monoid (T2 a b) where
#endif
    mempty  = T2 mempty mempty
    mappend = (<>)
-- | Componentwise arithmetic.  Each operator applies the same operation to
-- all three components; 'fromInteger' converts the literal separately for
-- each component.
instance (Num a, Num b, Num c) => Num (T3 a b c) where
    T3 x1 y1 z1 + T3 x2 y2 z2 = T3 (x1 + x2) (y1 + y2) (z1 + z2)
    -- The third component must use the same operator as the first two.
    T3 x1 y1 z1 - T3 x2 y2 z2 = T3 (x1 - x2) (y1 - y2) (z1 - z2)
    T3 x1 y1 z1 * T3 x2 y2 z2 = T3 (x1 * x2) (y1 * y2) (z1 * z2)
    negate (T3 x y z)         = T3 (negate x) (negate y) (negate z)
    abs    (T3 x y z)         = T3 (abs    x) (abs    y) (abs    z)
    signum (T3 x y z)         = T3 (signum x) (signum y) (signum z)
    fromInteger x             = T3 (fromInteger x) (fromInteger x) (fromInteger x)
-- | Componentwise division and reciprocal.  'fromRational' converts the
-- literal separately for each component.
instance (Fractional a, Fractional b, Fractional c) => Fractional (T3 a b c) where
    T3 a b c / T3 a' b' c' = T3 (a / a') (b / b') (c / c')
    recip (T3 a b c)       = T3 (recip a) (recip b) (recip c)
    fromRational r         = T3 (fromRational r) (fromRational r) (fromRational r)
-- | Componentwise transcendental functions.
--
-- 'tan' and 'tanh' are defined componentwise rather than left to the class
-- defaults (@sin x / cos x@ and @sinh x / cosh x@ on the whole tuple): the
-- default 'tanh' divides two overflowing values for large-magnitude inputs
-- and so produces NaN where the component's own 'tanh' does not.
instance (Floating a, Floating b, Floating c) => Floating (T3 a b c) where
    pi                                  = T3 pi pi pi
    T3 x1 y1 z1 ** T3 x2 y2 z2          = T3 (x1 ** x2) (y1 ** y2) (z1 ** z2)
    logBase (T3 x1 y1 z1) (T3 x2 y2 z2) = T3 (logBase x1 x2) (logBase y1 y2) (logBase z1 z2)
    exp   (T3 x y z)                    = T3 (exp   x) (exp   y) (exp   z)
    log   (T3 x y z)                    = T3 (log   x) (log   y) (log   z)
    sqrt  (T3 x y z)                    = T3 (sqrt  x) (sqrt  y) (sqrt  z)
    sin   (T3 x y z)                    = T3 (sin   x) (sin   y) (sin   z)
    cos   (T3 x y z)                    = T3 (cos   x) (cos   y) (cos   z)
    tan   (T3 x y z)                    = T3 (tan   x) (tan   y) (tan   z)
    asin  (T3 x y z)                    = T3 (asin  x) (asin  y) (asin  z)
    acos  (T3 x y z)                    = T3 (acos  x) (acos  y) (acos  z)
    atan  (T3 x y z)                    = T3 (atan  x) (atan  y) (atan  z)
    sinh  (T3 x y z)                    = T3 (sinh  x) (sinh  y) (sinh  z)
    cosh  (T3 x y z)                    = T3 (cosh  x) (cosh  y) (cosh  z)
    tanh  (T3 x y z)                    = T3 (tanh  x) (tanh  y) (tanh  z)
    asinh (T3 x y z)                    = T3 (asinh x) (asinh y) (asinh z)
    acosh (T3 x y z)                    = T3 (acosh x) (acosh y) (acosh z)
    atanh (T3 x y z)                    = T3 (atanh x) (atanh y) (atanh z)
-- | Combines two triples component by component.
instance (Semigroup a, Semigroup b, Semigroup c) => Semigroup (T3 a b c) where
    T3 a b c <> T3 a' b' c' = T3 (a <> a') (b <> b') (c <> c')
-- | The identity element holds 'mempty' in all three components; 'mappend'
-- is the 'Semigroup' operation.  On @base@ older than 4.11 the 'Semigroup'
-- constraints have to be listed explicitly.
#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#else
instance (Semigroup a, Semigroup b, Semigroup c, Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#endif
    mempty  = T3 mempty mempty mempty
    mappend = (<>)
-- | Build a 'T' in which every slot is filled from the same polymorphic
-- value, instantiated at that slot's type.  The constraint @c@ is meant to
-- be supplied by type application.
constT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a)     -- ^ value usable at every element type
    -> Length as                -- ^ witness for the number of elements
    -> T as
constT v = build
  where
    -- Walks the length witness, emitting one element per step.
    build :: forall bs. ListC (c <$> bs) => Length bs -> T bs
    build = \case
      LZ   -> TNil
      LS n -> v :& build n
-- | Apply a constraint-polymorphic, type-preserving function to every
-- element of a 'T'.  The constraint @c@ is meant to be supplied by type
-- application.
mapT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a)    -- ^ function applied to each element
    -> T as
    -> T as
mapT g = walk
  where
    walk :: forall bs. ListC (c <$> bs) => T bs -> T bs
    walk = \case
      TNil    -> TNil
      y :& ys -> g y :& walk ys
-- | Combine two tuples of the same shape element by element with a
-- constraint-polymorphic binary function.  The constraint @c@ is meant to be
-- supplied by type application.
zipT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a -> a)   -- ^ function combining paired elements
    -> T as
    -> T as
    -> T as
zipT g = walk
  where
    walk :: forall bs. ListC (c <$> bs) => T bs -> T bs -> T bs
    walk TNil      TNil      = TNil
    walk (l :& ls) (r :& rs) = g l r :& walk ls rs
-- | Elementwise arithmetic via 'zipT' and 'mapT'.  'fromInteger' fills every
-- slot with the converted literal, using the length witness from 'known'.
instance (Known Length as, ListC (Num <$> as)) => Num (T as) where
    xs + ys       = zipT @Num (+) xs ys
    xs - ys       = zipT @Num (-) xs ys
    xs * ys       = zipT @Num (*) xs ys
    negate xs     = mapT @Num negate xs
    abs xs        = mapT @Num abs xs
    signum xs     = mapT @Num signum xs
    fromInteger n = constT @Num (fromInteger n) known
-- | Elementwise division and reciprocal.  'fromRational' fills every slot
-- with the converted literal, using the length witness from 'known'.
instance (Known Length as, ListC (Num <$> as), ListC (Fractional <$> as)) => Fractional (T as) where
    xs / ys        = zipT @Fractional (/) xs ys
    recip xs       = mapT @Fractional recip xs
    fromRational r = constT @Fractional (fromRational r) known
-- | Elementwise transcendental functions via 'zipT' and 'mapT'; 'pi' fills
-- every slot using the length witness from 'known'.
--
-- 'tan' and 'tanh' are defined elementwise rather than left to the class
-- defaults (@sin x / cos x@ and @sinh x / cosh x@ on the whole tuple): the
-- default 'tanh' divides two overflowing values for large-magnitude inputs
-- and so produces NaN where the element's own 'tanh' does not.
instance (Known Length as, ListC (Num <$> as), ListC (Fractional <$> as), ListC (Floating <$> as))
        => Floating (T as) where
    pi      = constT @Floating pi known
    (**)    = zipT @Floating (**)
    logBase = zipT @Floating logBase
    exp     = mapT @Floating exp
    log     = mapT @Floating log
    sqrt    = mapT @Floating sqrt
    sin     = mapT @Floating sin
    cos     = mapT @Floating cos
    tan     = mapT @Floating tan
    asin    = mapT @Floating asin
    acos    = mapT @Floating acos
    atan    = mapT @Floating atan
    sinh    = mapT @Floating sinh
    cosh    = mapT @Floating cosh
    tanh    = mapT @Floating tanh
    asinh   = mapT @Floating asinh
    acosh   = mapT @Floating acosh
    atanh   = mapT @Floating atanh
-- | Combines two tuples of the same shape element by element.
instance ListC (Semigroup <$> as) => Semigroup (T as) where
    xs <> ys = zipT @Semigroup (<>) xs ys
-- | The identity element holds 'mempty' in every slot, built with the length
-- witness from 'known'; 'mappend' is the 'Semigroup' operation.
instance (Known Length as, ListC (Semigroup <$> as), ListC (Monoid <$> as)) => Monoid (T as) where
    mappend = (<>)
    mempty  = constT @Monoid mempty known