{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Numeric.Backprop.Tuple (
T0(..)
, T2(..)
, t2Tup, tupT2
, uncurryT2, curryT2
, t2_1, t2_2
, T3(..)
, t3Tup, tupT3
, t3_1, t3_2, t3_3
, uncurryT3, curryT3
, T(..)
, indexT
, tOnly, onlyT, tSplit, tAppend, tProd, prodT
, tIx, tHead, tTail, tTake, tDrop
, constT, mapT, zipT
) where
import Control.DeepSeq
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Data
import Data.Kind
import Data.Type.Combinator
import Data.Type.Index
import Data.Type.Length
import Data.Type.Product
import GHC.Generics (Generic)
import Lens.Micro
import Lens.Micro.Internal hiding (Index)
import System.Random
import Type.Class.Known
import Type.Family.List
import qualified Data.Binary as Bi
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
-- | The nullary tuple: a type whose only constructor, 'T0', has no fields.
data T0 = T0
    deriving (Data, Eq, Generic, Ord, Read, Show)

-- | A pair whose two fields are strict, so evaluating a 'T2' to weak head
-- normal form also evaluates both components to weak head normal form.
data T2 a b = T2 !a !b
    deriving (Data, Eq, Functor, Generic, Ord, Read, Show, Typeable)

-- | A triple whose three fields are all strict.
data T3 a b c = T3 !a !b !c
    deriving (Data, Eq, Functor, Generic, Ord, Read, Show, Typeable)
-- | A strict heterogeneous list indexed by the type-level list of its
-- element types.  Head and tail are both strict fields, so evaluating a 'T'
-- to weak head normal form evaluates its entire spine and every element to
-- weak head normal form.
data T :: [Type] -> Type where
    -- | The empty list.
    TNil :: T '[]
    -- | Prepends an element of type @a@ to a 'T' holding types @as@.
    (:&) :: forall a as. !a -> !(T as) -> T (a ': as)
-- | Structural 'Show', 'Eq' and 'Ord' for 'T'.  Each is available whenever
-- every element type has the corresponding instance.  'Ord' additionally
-- asks for the element-wise 'Eq' constraint because 'Eq' is a superclass of
-- 'Ord'.
--
-- No 'Typeable' instances are derived here.  GHC provides 'Typeable' for
-- every type automatically, so explicit derivations are redundant; GHC only
-- warns about them (-Wderiving-typeable).
deriving instance ListC (Show <$> as) => Show (T as)
deriving instance ListC (Eq <$> as) => Eq (T as)
deriving instance (ListC (Eq <$> as), ListC (Ord <$> as)) => Ord (T as)
-- | Deep evaluation for the fixed-size tuples.  No methods are given, so the
-- class default 'rnf' is used (presumably the 'Generic'-based one, since
-- these types derive 'Generic').
instance NFData T0

instance (NFData a, NFData b) => NFData (T2 a b)

instance (NFData a, NFData b, NFData c) => NFData (T3 a b c)

-- | Deep evaluation for 'T' by structural recursion: 'rnf' is applied to the
-- head and to the rest of the list.
instance ListC (NFData <$> as) => NFData (T as) where
    rnf TNil      = ()
    rnf (y :& ys) = rnf y `seq` rnf ys
-- | 'T0' has a single value, so every method produces 'T0'.  The pure
-- methods return the input generator unchanged.
instance Random T0 where
    randomR _ g  = (T0, g)
    random g     = (T0, g)
    randomRs _ _ = repeat T0
    randoms _    = repeat T0
    randomIO     = pure T0
-- | Draws each field in turn, threading the generator from the first field
-- to the second.  'randomR' uses the matching fields of the two bounds as
-- each field's range.
instance (Random a, Random b) => Random (T2 a b) where
    randomR (T2 lx ly, T2 ux uy) = runState $ do
        x <- state (randomR (lx, ux))
        y <- state (randomR (ly, uy))
        pure (T2 x y)
    random = runState $ do
        x <- state random
        y <- state random
        pure (T2 x y)

-- | Draws each field in turn, threading the generator from the first field
-- to the third.  'randomR' uses the matching fields of the two bounds as
-- each field's range.
instance (Random a, Random b, Random c) => Random (T3 a b c) where
    randomR (T3 lx ly lz, T3 ux uy uz) = runState $ do
        x <- state (randomR (lx, ux))
        y <- state (randomR (ly, uy))
        z <- state (randomR (lz, uz))
        pure (T3 x y z)
    random = runState $ do
        x <- state random
        y <- state random
        z <- state random
        pure (T3 x y z)
-- | Serialisation for the fixed-size tuples.  No methods are given, so the
-- class defaults are used (presumably the 'Generic'-based ones, since all
-- three types derive 'Generic').
instance Bi.Binary T0

instance (Bi.Binary a, Bi.Binary b)
      => Bi.Binary (T2 a b)

instance (Bi.Binary a, Bi.Binary b, Bi.Binary c)
      => Bi.Binary (T3 a b c)
-- | Maps the first field with @f@ and the second with @g@.
instance Bifunctor T2 where
    bimap f g = \(T2 x y) -> T2 (f x) (g y)

-- | Maps the second and third fields of a 'T3'; the first field is kept.
instance Bifunctor (T3 a) where
    bimap f g = \(T3 x y z) -> T3 x (f y) (g z)
-- | Converts a 'T2' into an ordinary pair.
t2Tup :: T2 a b -> (a, b)
t2Tup (T2 a b) = (a, b)

-- | Converts an ordinary pair into a 'T2'.
tupT2 :: (a, b) -> T2 a b
tupT2 = uncurry T2

-- | Converts a 'T3' into an ordinary triple.
t3Tup :: T3 a b c -> (a, b, c)
t3Tup t = case t of
    T3 a b c -> (a, b, c)

-- | Converts an ordinary triple into a 'T3'.
tupT3 :: (a, b, c) -> T3 a b c
tupT3 t = case t of
    (a, b, c) -> T3 a b c
-- | Wraps a single value as a one-element 'T'.
onlyT :: a -> T '[a]
onlyT x = x :& TNil

-- | Extracts the sole element of a one-element 'T'.
tOnly :: T '[a] -> a
tOnly = \case
    x :& _ -> x
-- | Applies a two-argument function to the fields of a 'T2'.
uncurryT2 :: (a -> b -> c) -> T2 a b -> c
uncurryT2 f = \(T2 x y) -> f x y

-- | Turns a function on 'T2' into a two-argument function.
curryT2 :: (T2 a b -> c) -> a -> b -> c
curryT2 f = \x y -> f (T2 x y)

-- | Applies a three-argument function to the fields of a 'T3'.
uncurryT3 :: (a -> b -> c -> d) -> T3 a b c -> d
uncurryT3 f = \(T3 x y z) -> f x y z

-- | Turns a function on 'T3' into a three-argument function.
curryT3 :: (T3 a b c -> d) -> a -> b -> c -> d
curryT3 f = \x y z -> f (T3 x y z)
-- | Type-changing tuple-field lenses for 'T2' and 'T3', reusing the
-- standalone field lenses.
instance Field1 (T2 a b) (T2 a' b) a a' where
    _1 = t2_1

instance Field2 (T2 a b) (T2 a b') b b' where
    _2 = t2_2

instance Field1 (T3 a b c) (T3 a' b c) a a' where
    _1 = t3_1

instance Field2 (T3 a b c) (T3 a b' c) b b' where
    _2 = t3_2

instance Field3 (T3 a b c) (T3 a b c') c c' where
    _3 = t3_3

-- | Positional lenses into a 'T'.  These keep the focused element's type
-- (they are not type-changing).  Each is built by skipping elements with
-- 'tTail' and then focusing the current head with 'tHead'.
instance Field1 (T (a ': as)) (T (a ': as)) a a where
    _1 = tHead

instance Field2 (T (a ': b ': as)) (T (a ': b ': as)) b b where
    _2 = tTail . tHead

instance Field3 (T (a ': b ': c ': as)) (T (a ': b ': c ': as)) c c where
    _3 = tTail . tTail . tHead
-- | Type-changing lens onto the first field of a 'T2'.
t2_1 :: Lens (T2 a b) (T2 a' b) a a'
t2_1 f (T2 x y) = fmap (\x' -> T2 x' y) (f x)

-- | Type-changing lens onto the second field of a 'T2'.
t2_2 :: Lens (T2 a b) (T2 a b') b b'
t2_2 f (T2 x y) = fmap (T2 x) (f y)

-- | Type-changing lens onto the first field of a 'T3'.
t3_1 :: Lens (T3 a b c) (T3 a' b c) a a'
t3_1 f (T3 x y z) = fmap (\x' -> T3 x' y z) (f x)

-- | Type-changing lens onto the second field of a 'T3'.
t3_2 :: Lens (T3 a b c) (T3 a b' c) b b'
t3_2 f (T3 x y z) = fmap (\y' -> T3 x y' z) (f y)

-- | Type-changing lens onto the third field of a 'T3'.
t3_3 :: Lens (T3 a b c) (T3 a b c') c c'
t3_3 f (T3 x y z) = fmap (T3 x y) (f z)
-- | Reads the element of a 'T' at the given type-level 'Index'.
indexT :: Index as a -> T as -> a
indexT i xs = xs ^. tIx i

-- | Lens onto the element at the given 'Index'.  It is not type-changing:
-- the focused element keeps its type.  'IZ' focuses the head; 'IS' recurses
-- into the tail.
tIx :: Index as a -> Lens' (T as) a
tIx IZ     f (y :& ys) = fmap (:& ys) (f y)
tIx (IS j) f (y :& ys) = fmap (y :&) (tIx j f ys)
-- | Type-changing lens onto the first element of a non-empty 'T'.
tHead :: Lens (T (a ': as)) (T (b ': as)) a b
tHead f (h :& rest) = fmap (:& rest) (f h)

-- | Type-changing lens onto everything after the first element.
tTail :: Lens (T (a ': as)) (T (a ': bs)) (T as) (T bs)
tTail f (h :& rest) = fmap (h :&) (f rest)
-- | Concatenates two 'T's; the element types are concatenated at the type
-- level with '++'.
tAppend :: T as -> T bs -> T (as ++ bs)
tAppend xs ys = case xs of
    TNil    -> ys
    z :& zs -> z :& tAppend zs ys
infixr 5 `tAppend`
-- | Splits a 'T' into a prefix whose length is given by the 'Length'
-- argument and the remaining suffix.
tSplit :: Length as -> T (as ++ bs) -> (T as, T bs)
tSplit LZ     ys        = (TNil, ys)
tSplit (LS n) (y :& ys) =
    let (front, back) = tSplit n ys
    in  (y :& front, back)
-- | Lens onto the first @as@-many elements of a 'T' (length given by the
-- 'Length' argument); the prefix may be replaced by one of a different type.
tTake :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (cs ++ bs)) (T as) (T cs)
tTake l f xs = case tSplit l xs of
    (front, back) -> fmap (\front' -> tAppend @cs @bs front' back) (f front)

-- | Lens onto everything after the first @as@-many elements of a 'T'; the
-- suffix may be replaced by one of a different type.
tDrop :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (as ++ cs)) (T bs) (T cs)
tDrop l f xs = case tSplit l xs of
    (front, back) -> fmap (tAppend front) (f back)
-- | Converts a 'T' into the corresponding 'Tuple' (a product of
-- 'I'-wrapped values).
tProd :: T as -> Tuple as
tProd = \case
    TNil    -> Ø
    y :& ys -> y ::< tProd ys

-- | Converts a 'Tuple' back into a 'T', unwrapping each 'I'.
prodT :: Tuple as -> T as
prodT = \case
    Ø         -> TNil
    I y :< ys -> y :& prodT ys
-- | Arithmetic on 'T0' is trivial: every operation ignores its arguments
-- and yields 'T0'.
instance Num T0 where
    (+)         = \_ _ -> T0
    (-)         = \_ _ -> T0
    (*)         = \_ _ -> T0
    negate      = const T0
    abs         = const T0
    signum      = const T0
    fromInteger = const T0

instance Fractional T0 where
    (/)          = \_ _ -> T0
    recip        = const T0
    fromRational = const T0

instance Floating T0 where
    pi      = T0
    (**)    = \_ _ -> T0
    logBase = \_ _ -> T0
    exp     = const T0
    log     = const T0
    sqrt    = const T0
    sin     = const T0
    cos     = const T0
    asin    = const T0
    acos    = const T0
    atan    = const T0
    sinh    = const T0
    cosh    = const T0
    asinh   = const T0
    acosh   = const T0
    atanh   = const T0
-- | The trivial monoid: combining any two values yields 'T0'.
instance Semigroup T0 where
    (<>) = \_ _ -> T0

instance Monoid T0 where
    mappend = (<>)
    mempty  = T0
-- | Component-wise arithmetic on 'T2'.  Numeric literals are replicated into
-- both fields.
instance (Num a, Num b) => Num (T2 a b) where
    (+) (T2 a1 b1) (T2 a2 b2) = T2 (a1 + a2) (b1 + b2)
    (-) (T2 a1 b1) (T2 a2 b2) = T2 (a1 - a2) (b1 - b2)
    (*) (T2 a1 b1) (T2 a2 b2) = T2 (a1 * a2) (b1 * b2)
    negate (T2 a b) = T2 (negate a) (negate b)
    abs    (T2 a b) = T2 (abs a)    (abs b)
    signum (T2 a b) = T2 (signum a) (signum b)
    fromInteger n   = T2 (fromInteger n) (fromInteger n)

-- | Component-wise division; rational literals are replicated into both
-- fields.
instance (Fractional a, Fractional b) => Fractional (T2 a b) where
    (/) (T2 a1 b1) (T2 a2 b2) = T2 (a1 / a2) (b1 / b2)
    recip (T2 a b)   = T2 (recip a) (recip b)
    fromRational r   = T2 (fromRational r) (fromRational r)

-- | Component-wise transcendental functions; 'pi' fills both fields.
instance (Floating a, Floating b) => Floating (T2 a b) where
    pi = T2 pi pi
    (**) (T2 a1 b1) (T2 a2 b2)    = T2 (a1 ** a2) (b1 ** b2)
    logBase (T2 a1 b1) (T2 a2 b2) = T2 (logBase a1 a2) (logBase b1 b2)
    exp   (T2 a b) = T2 (exp a)   (exp b)
    log   (T2 a b) = T2 (log a)   (log b)
    sqrt  (T2 a b) = T2 (sqrt a)  (sqrt b)
    sin   (T2 a b) = T2 (sin a)   (sin b)
    cos   (T2 a b) = T2 (cos a)   (cos b)
    asin  (T2 a b) = T2 (asin a)  (asin b)
    acos  (T2 a b) = T2 (acos a)  (acos b)
    atan  (T2 a b) = T2 (atan a)  (atan b)
    sinh  (T2 a b) = T2 (sinh a)  (sinh b)
    cosh  (T2 a b) = T2 (cosh a)  (cosh b)
    asinh (T2 a b) = T2 (asinh a) (asinh b)
    acosh (T2 a b) = T2 (acosh a) (acosh b)
    atanh (T2 a b) = T2 (atanh a) (atanh b)
-- | Component-wise append.
instance (Semigroup a, Semigroup b) => Semigroup (T2 a b) where
    (<>) (T2 a1 b1) (T2 a2 b2) = T2 (a1 <> a2) (b1 <> b2)

-- | The identity pairs the identities of both fields.  Before base 4.11,
-- 'Semigroup' was not a superclass of 'Monoid', so it is required explicitly
-- there to make '(<>)' available for 'mappend'.
#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b) => Monoid (T2 a b) where
#else
instance (Semigroup a, Semigroup b, Monoid a, Monoid b) => Monoid (T2 a b) where
#endif
    mempty  = T2 mempty mempty
    mappend = (<>)
-- | Component-wise arithmetic on strict triples.  Each binary method
-- combines the matching fields with the same operator, each unary method
-- transforms every field, and 'fromInteger' replicates the literal into all
-- three fields.
instance (Num a, Num b, Num c) => Num (T3 a b c) where
    T3 x1 y1 z1 + T3 x2 y2 z2 = T3 (x1 + x2) (y1 + y2) (z1 + z2)
    -- The third field must use the same operator as the first two.
    T3 x1 y1 z1 - T3 x2 y2 z2 = T3 (x1 - x2) (y1 - y2) (z1 - z2)
    T3 x1 y1 z1 * T3 x2 y2 z2 = T3 (x1 * x2) (y1 * y2) (z1 * z2)
    negate (T3 x y z) = T3 (negate x) (negate y) (negate z)
    abs    (T3 x y z) = T3 (abs x)    (abs y)    (abs z)
    signum (T3 x y z) = T3 (signum x) (signum y) (signum z)
    fromInteger n     = T3 (fromInteger n) (fromInteger n) (fromInteger n)
-- | Component-wise division; rational literals are replicated into all three
-- fields.
instance (Fractional a, Fractional b, Fractional c) => Fractional (T3 a b c) where
    (/) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 / a2) (b1 / b2) (c1 / c2)
    recip (T3 a b c) = T3 (recip a) (recip b) (recip c)
    fromRational r   = T3 (fromRational r) (fromRational r) (fromRational r)

-- | Component-wise transcendental functions; 'pi' fills all three fields.
instance (Floating a, Floating b, Floating c) => Floating (T3 a b c) where
    pi = T3 pi pi pi
    (**) (T3 a1 b1 c1) (T3 a2 b2 c2)    = T3 (a1 ** a2) (b1 ** b2) (c1 ** c2)
    logBase (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (logBase a1 a2) (logBase b1 b2) (logBase c1 c2)
    exp   (T3 a b c) = T3 (exp a)   (exp b)   (exp c)
    log   (T3 a b c) = T3 (log a)   (log b)   (log c)
    sqrt  (T3 a b c) = T3 (sqrt a)  (sqrt b)  (sqrt c)
    sin   (T3 a b c) = T3 (sin a)   (sin b)   (sin c)
    cos   (T3 a b c) = T3 (cos a)   (cos b)   (cos c)
    asin  (T3 a b c) = T3 (asin a)  (asin b)  (asin c)
    acos  (T3 a b c) = T3 (acos a)  (acos b)  (acos c)
    atan  (T3 a b c) = T3 (atan a)  (atan b)  (atan c)
    sinh  (T3 a b c) = T3 (sinh a)  (sinh b)  (sinh c)
    cosh  (T3 a b c) = T3 (cosh a)  (cosh b)  (cosh c)
    asinh (T3 a b c) = T3 (asinh a) (asinh b) (asinh c)
    acosh (T3 a b c) = T3 (acosh a) (acosh b) (acosh c)
    atanh (T3 a b c) = T3 (atanh a) (atanh b) (atanh c)
-- | Component-wise append.
instance (Semigroup a, Semigroup b, Semigroup c) => Semigroup (T3 a b c) where
    (<>) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 <> a2) (b1 <> b2) (c1 <> c2)

-- | The identity holds the identity of every field.  Before base 4.11,
-- 'Semigroup' was not a superclass of 'Monoid', so it is required explicitly
-- there to make '(<>)' available for 'mappend'.
#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#else
instance (Semigroup a, Semigroup b, Semigroup c, Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#endif
    mempty  = T3 mempty mempty mempty
    mappend = (<>)
-- | Builds a 'T' whose shape is given by the 'Length' argument.  Every slot
-- holds the constraint-polymorphic value @x@, instantiated at that slot's
-- type.  The constraint @c@ is normally chosen with a type application.
constT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a)
    -> Length as
    -> T as
constT x = build
  where
    build :: forall bs. ListC (c <$> bs) => Length bs -> T bs
    build = \case
        LZ   -> TNil
        LS n -> x :& build n
-- | Applies a type-preserving, constraint-polymorphic function to every
-- element of a 'T'.  The constraint @c@ is normally chosen with a type
-- application.
mapT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a)
    -> T as
    -> T as
mapT f = transform
  where
    transform :: forall bs. ListC (c <$> bs) => T bs -> T bs
    transform = \case
        TNil    -> TNil
        y :& ys -> f y :& transform ys
-- | Combines two 'T's of the same element types position by position using
-- a constraint-polymorphic binary function.  The constraint @c@ is normally
-- chosen with a type application.
zipT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a -> a)
    -> T as
    -> T as
    -> T as
zipT f = combine
  where
    combine :: forall bs. ListC (c <$> bs) => T bs -> T bs -> T bs
    combine TNil        TNil        = TNil
    combine (l :& rest) (r :& rest') = f l r :& combine rest rest'
-- | Element-wise arithmetic on a 'T'.  'fromInteger' replicates the literal
-- into every slot, using the type-level 'Known' 'Length' for the shape.
instance (Known Length as, ListC (Num <$> as)) => Num (T as) where
    xs + ys       = zipT @Num (+) xs ys
    xs - ys       = zipT @Num (-) xs ys
    xs * ys       = zipT @Num (*) xs ys
    negate xs     = mapT @Num negate xs
    abs xs        = mapT @Num abs xs
    signum xs     = mapT @Num signum xs
    fromInteger n = constT @Num (fromInteger n) known
-- | Element-wise division on a 'T'.  Rational literals are replicated into
-- every slot.
instance ( Known Length as
         , ListC (Num <$> as)
         , ListC (Fractional <$> as)
         ) => Fractional (T as) where
    xs / ys        = zipT @Fractional (/) xs ys
    recip xs       = mapT @Fractional recip xs
    fromRational r = constT @Fractional (fromRational r) known

-- | Element-wise transcendental functions on a 'T'.  'pi' is replicated into
-- every slot.
instance ( Known Length as
         , ListC (Num <$> as)
         , ListC (Fractional <$> as)
         , ListC (Floating <$> as)
         ) => Floating (T as) where
    pi          = constT @Floating pi known
    xs ** ys    = zipT @Floating (**) xs ys
    logBase b x = zipT @Floating logBase b x
    exp xs      = mapT @Floating exp xs
    log xs      = mapT @Floating log xs
    sqrt xs     = mapT @Floating sqrt xs
    sin xs      = mapT @Floating sin xs
    cos xs      = mapT @Floating cos xs
    asin xs     = mapT @Floating asin xs
    acos xs     = mapT @Floating acos xs
    atan xs     = mapT @Floating atan xs
    sinh xs     = mapT @Floating sinh xs
    cosh xs     = mapT @Floating cosh xs
    asinh xs    = mapT @Floating asinh xs
    acosh xs    = mapT @Floating acosh xs
    atanh xs    = mapT @Floating atanh xs
-- | Element-wise append on a 'T'.
instance ListC (Semigroup <$> as) => Semigroup (T as) where
    xs <> ys = zipT @Semigroup (<>) xs ys

-- | 'mempty' puts 'mempty' in every slot, using the type-level 'Known'
-- 'Length' for the shape.
instance (Known Length as, ListC (Semigroup <$> as), ListC (Monoid <$> as)) => Monoid (T as) where
    mappend = (<>)
    mempty  = constT @Monoid mempty known
-- | Serialises a 'T' as its elements in order, with no length prefix.  When
-- decoding, the number of elements comes from the type-level 'Known'
-- 'Length'.
instance (Known Length as, ListC (Bi.Binary <$> as)) => Bi.Binary (T as) where
    put TNil      = pure ()
    put (y :& ys) = Bi.put y >> Bi.put ys
    get = getT known

-- | Decodes a 'T' of the given 'Length' by reading each element in turn.
getT :: ListC (Bi.Binary <$> as) => Length as -> Bi.Get (T as)
getT LZ     = pure TNil
getT (LS n) = (:&) <$> Bi.get <*> getT n
-- | Random 'T' values, drawn element by element with the generator threaded
-- from the first element to the last.  'randomR' uses the matching elements
-- of the two bounds as each element's range.  'random' takes the number of
-- elements from the type-level 'Known' 'Length'.
instance (Known Length as, ListC (Random <$> as)) => Random (T as) where
    randomR (lo, hi) = runState (randomRT lo hi)
    random           = runState (randomT known)

-- | Stateful helper for 'randomR': walks the lower and upper bounds in
-- lockstep, drawing one element per position.
randomRT
    :: (ListC (Random <$> as), RandomGen g)
    => T as
    -> T as
    -> State g (T as)
randomRT TNil        TNil        = pure TNil
randomRT (lo :& los) (hi :& his) = do
    x  <- state (randomR (lo, hi))
    xs <- randomRT los his
    pure (x :& xs)

-- | Stateful helper for 'random': draws one element per step of the
-- 'Length'.
randomT
    :: (ListC (Random <$> as), RandomGen g)
    => Length as
    -> State g (T as)
randomT LZ     = pure TNil
randomT (LS n) = do
    x  <- state random
    xs <- randomT n
    pure (x :& xs)