{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.Backprop.Explicit (
BVar, W, Backprop(..), ABP(..), NumBP(..)
, ZeroFunc(..), zfNum, zfNums, zeroFunc, zeroFuncs, zfFunctor
, AddFunc(..), afNum, afNums, addFunc, addFuncs
, OneFunc(..), ofNum, ofNums, oneFunc, oneFuncs, ofFunctor
, backprop, evalBP, gradBP, backpropWith
, evalBP0
, backprop2, evalBP2, gradBP2, backpropWith2
, backpropN, evalBPN, gradBPN, backpropWithN, Every
, constVar, auto, coerceVar
, viewVar, setVar
, sequenceVar, collectVar
, previewVar, toListOfVar
, isoVar, isoVar2, isoVar3, isoVarN
, liftOp
, liftOp1, liftOp2, liftOp3
, splitBV
, joinBV
, BVGroup
, Op(..)
, op0, opConst, idOp
, opConst'
, op1, op2, op3
, opCoerce, opTup, opIso, opIsoN, opLens
, noGrad1, noGrad
, Prod(..), pattern (:>), only, head'
, Tuple, pattern (::<), only_
, I(..)
, Reifies
) where
import Data.Bifunctor
import Data.Reflection
import Data.Type.Index
import Data.Type.Length
import Data.Type.Product
import Data.Type.Util
import GHC.Generics as G
import Lens.Micro
import Numeric.Backprop.Class
import Numeric.Backprop.Internal
import Numeric.Backprop.Op
import Type.Class.Higher
import Type.Class.Known
import Type.Class.Witness
import Type.Family.List
import Unsafe.Coerce
-- | 'ZeroFunc's for every type in the type-level list @as@, each one
-- built with 'zfNum' from that element type's 'Num' instance.
zfNums :: (Every Num as, Known Length as) => Prod ZeroFunc as
-- 'every' recovers the 'Num' dictionary for the element at index @i@,
-- which is then discharged with '\\' so 'zfNum' can be used there.
zfNums = map1 (\i -> zfNum \\ every @_ @Num i) indices
{-# INLINE zfNums #-}
-- | A 'ZeroFunc' for any 'Functor' @f@ whose elements have a 'Backprop'
-- instance; the wrapped function is 'zeroFunctor'.
zfFunctor :: (Backprop a, Functor f) => ZeroFunc (f a)
zfFunctor = ZF $ \xs -> zeroFunctor xs
{-# INLINE zfFunctor #-}
-- | 'AddFunc's for every type in the type-level list @as@, each one
-- built with 'afNum' from that element type's 'Num' instance.
afNums :: (Every Num as, Known Length as) => Prod AddFunc as
-- 'every' supplies the per-index 'Num' dictionary that 'afNum' needs.
afNums = map1 (\i -> afNum \\ every @_ @Num i) indices
{-# INLINE afNums #-}
-- | 'OneFunc's for every type in the type-level list @as@, each one
-- built with 'ofNum' from that element type's 'Num' instance.
ofNums :: (Every Num as, Known Length as) => Prod OneFunc as
-- 'every' supplies the per-index 'Num' dictionary that 'ofNum' needs.
ofNums = map1 (\i -> ofNum \\ every @_ @Num i) indices
{-# INLINE ofNums #-}
-- | A 'OneFunc' for any 'Functor' @f@ whose elements have a 'Backprop'
-- instance; the wrapped function is 'oneFunctor'.
ofFunctor :: (Backprop a, Functor f) => OneFunc (f a)
ofFunctor = OF $ \xs -> oneFunctor xs
{-# INLINE ofFunctor #-}
-- | 'ZeroFunc's for every type in the type-level list @as@, each one
-- taken from that element type's 'Backprop' instance via 'zeroFunc'.
zeroFuncs :: (Every Backprop as, Known Length as) => Prod ZeroFunc as
-- 'every' supplies the per-index 'Backprop' dictionary for 'zeroFunc'.
zeroFuncs = map1 (\i -> zeroFunc \\ every @_ @Backprop i) indices
{-# INLINE zeroFuncs #-}
-- | 'AddFunc's for every type in the type-level list @as@, each one
-- taken from that element type's 'Backprop' instance via 'addFunc'.
addFuncs :: (Every Backprop as, Known Length as) => Prod AddFunc as
-- 'every' supplies the per-index 'Backprop' dictionary for 'addFunc'.
addFuncs = map1 (\i -> addFunc \\ every @_ @Backprop i) indices
{-# INLINE addFuncs #-}
-- | 'OneFunc's for every type in the type-level list @as@, each one
-- taken from that element type's 'Backprop' instance via 'oneFunc'.
oneFuncs :: (Every Backprop as, Known Length as) => Prod OneFunc as
-- 'every' supplies the per-index 'Backprop' dictionary for 'oneFunc'.
oneFuncs = map1 (\i -> oneFunc \\ every @_ @Backprop i) indices
{-# INLINE oneFuncs #-}
-- | Embed a plain value as a constant 'BVar'; another name for 'constVar'.
auto :: a -> BVar s a
auto x = constVar x
{-# INLINE auto #-}
-- | 'backpropN', except that the returned back-propagation function takes
-- a plain @b -> b@ instead of a 'OneFunc' @b@. The argument is wrapped
-- with 'OF' before being passed to 'backpropN''s function.
backpropWithN
    :: Prod ZeroFunc as
    -> (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> (b, (b -> b) -> Tuple as)
backpropWithN zfs f xs = (result, withOne)
  where
    (result, backOne) = backpropN zfs f xs
    withOne           = backOne . OF
{-# INLINE backpropWithN #-}
-- | 'backpropN' specialised to a single input.
--
-- Runs @f@ on @x@ and returns the result together with a function that,
-- given a 'OneFunc' for the result, yields the gradient for @x@. That
-- gradient is the only entry of 'backpropN''s tuple, unwrapped.
backprop
    :: ZeroFunc a
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b, OneFunc b -> a)
backprop zfa f x = (result, getI . head' . backOne)
  where
    (result, backOne) = backpropN (zfa :< Ø) (f . head') (only_ x)
{-# INLINE backprop #-}
-- | 'backprop', except that the returned function takes a plain @b -> b@
-- (wrapped with 'OF') instead of a 'OneFunc' @b@.
backpropWith
    :: ZeroFunc a
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b, (b -> b) -> a)
backpropWith zfa f x = (result, backOne . OF)
  where
    (result, backOne) = backprop zfa f x
{-# INLINE backpropWith #-}
-- | Evaluate a computation that takes no inputs: 'evalBPN' applied to the
-- empty tuple, ignoring the (empty) product of input 'BVar's.
evalBP0 :: (forall s. Reifies s W => BVar s a) -> a
evalBP0 x = evalBPN (\_ -> x) Ø
{-# INLINE evalBP0 #-}
-- | Return only the result of applying @f@ to @x@, by running 'evalBPN'
-- on a one-element tuple.
evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> b
evalBP f x = evalBPN (\xs -> f (head' xs)) (only_ x)
{-# INLINE evalBP #-}
-- | The gradient half of 'backprop': the result's 'OneFunc' is supplied
-- up front and only the gradient for the input is returned.
gradBP
    :: ZeroFunc a
    -> OneFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> a
gradBP zfa ofb f x = snd (backprop zfa f x) ofb
{-# INLINE gradBP #-}
-- | The gradient half of 'backpropN': the result's 'OneFunc' is supplied
-- up front and only the tuple of input gradients is returned.
gradBPN
    :: Prod ZeroFunc as
    -> OneFunc b
    -> (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> Tuple as
gradBPN zfas ofb f xs = backOne ofb
  where
    (_, backOne) = backpropN zfas f xs
{-# INLINE gradBPN #-}
-- | 'backpropN' specialised to two inputs.
--
-- The inputs are packed into a two-element tuple. The two-element gradient
-- tuple produced by 'backpropN' is unpacked into an ordinary pair.
backprop2
    :: ZeroFunc a
    -> ZeroFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c, OneFunc c -> (a, b))
backprop2 zfa zfb f x y =
    (result, \o -> case backOne o of dx ::< dy ::< Ø -> (dx, dy))
  where
    (result, backOne) = backpropN (zfa :< zfb :< Ø)
                                  (\(bx :< by :< Ø) -> f bx by)
                                  (x ::< y ::< Ø)
{-# INLINE backprop2 #-}
-- | 'backprop2', except that the returned function takes a plain @c -> c@
-- (wrapped with 'OF') instead of a 'OneFunc' @c@.
backpropWith2
    :: ZeroFunc a
    -> ZeroFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c, (c -> c) -> (a, b))
backpropWith2 zfa zfb f x y = (result, backOne . OF)
  where
    (result, backOne) = backprop2 zfa zfb f x y
{-# INLINE backpropWith2 #-}
-- | Return only the result of applying a two-argument function, by running
-- 'evalBPN' on a two-element tuple.
evalBP2
    :: (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> c
evalBP2 f x y = evalBPN (\(bx :< by :< Ø) -> f bx by) inputs
  where
    inputs = x ::< y ::< Ø
{-# INLINE evalBP2 #-}
-- | The gradient half of 'backprop2': the result's 'OneFunc' is supplied
-- up front and only the pair of input gradients is returned.
gradBP2
    :: ZeroFunc a
    -> ZeroFunc b
    -> OneFunc c
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (a, b)
gradBP2 zfa zfb ofc f x y = backOne ofc
  where
    (_, backOne) = backprop2 zfa zfb f x y
{-# INLINE gradBP2 #-}
-- | Lift a pair of conversions into a 'BVar' operation, using 'opIso' and
-- 'liftOp1'. The conversions are presumably expected to be mutual inverses.
--
-- @af@ is the 'AddFunc' for the input @a@; @z@ is the 'ZeroFunc' for the
-- output @b@.
isoVar
    :: Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> (a -> b)
    -> (b -> a)
    -> BVar s a
    -> BVar s b
isoVar af z fwd bwd = liftOp1 af z isoOp
  where
    isoOp = opIso fwd bwd
{-# INLINE isoVar #-}
-- | Two-input version of 'isoVar': combine two 'BVar's with @fwd@, while
-- @bwd@ splits the result back into its two parts ('opIso2' + 'liftOp2').
--
-- @afa@ / @afb@ are the 'AddFunc's for the inputs; @z@ is the 'ZeroFunc'
-- for the output.
isoVar2
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc c
    -> (a -> b -> c)
    -> (c -> (a, b))
    -> BVar s a
    -> BVar s b
    -> BVar s c
isoVar2 afa afb z fwd bwd = liftOp2 afa afb z isoOp
  where
    isoOp = opIso2 fwd bwd
{-# INLINE isoVar2 #-}
-- | Three-input version of 'isoVar', built from 'opIso3' and 'liftOp3'.
--
-- @afa@ / @afb@ / @afc@ are the 'AddFunc's for the inputs; @z@ is the
-- 'ZeroFunc' for the output.
isoVar3
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> ZeroFunc d
    -> (a -> b -> c -> d)
    -> (d -> (a, b, c))
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
isoVar3 afa afb afc z fwd bwd = liftOp3 afa afb afc z isoOp
  where
    isoOp = opIso3 fwd bwd
{-# INLINE isoVar3 #-}
-- | N-input version of 'isoVar', built from 'opIsoN' and 'liftOp'.
--
-- @afs@ holds one 'AddFunc' per input; @z@ is the 'ZeroFunc' for the
-- output.
isoVarN
    :: Reifies s W
    => Prod AddFunc as
    -> ZeroFunc b
    -> (Tuple as -> b)
    -> (b -> Tuple as)
    -> Prod (BVar s) as
    -> BVar s b
isoVarN afs z fwd bwd = liftOp afs z isoOp
  where
    isoOp = opIsoN fwd bwd
{-# INLINE isoVarN #-}
-- | Structural splitting and joining of 'BVar's over "GHC.Generics"
-- representations; this is the machinery used by 'splitBV' and 'joinBV'.
--
-- * @i@ is the representation of the plain value.
-- * @o@ is the matching representation whose fields are 'BVar's.
-- * @as@ lists the types that need an 'AddFunc' / 'ZeroFunc'. For every
--   product or sum node it holds that node's two halves (@i1 ()@, @i2 ()@),
--   followed by the entries needed by those halves.
class BVGroup s as i o | o -> i, i -> as where
    -- | Break a 'BVar' of representation @i ()@ into the representation
    -- @o ()@ of per-field 'BVar's.
    gsplitBV
        :: Prod AddFunc as
        -> Prod ZeroFunc as
        -> BVar s (i ())
        -> o ()
    -- | Reassemble per-field 'BVar's (@o ()@) into a single 'BVar' of
    -- representation @i ()@.
    gjoinBV
        :: Prod AddFunc as
        -> Prod ZeroFunc as
        -> o ()
        -> BVar s (i ())
-- | Leaves: a field of type @a@ corresponds to a field of type @BVar s a@.
-- Both directions convert across the 'K1' wrapper with 'coerceVar'.
-- No 'AddFunc' / 'ZeroFunc' entries are needed.
instance BVGroup s '[] (K1 i a) (K1 i (BVar s a)) where
    gsplitBV _ _ bv = K1 (coerceVar bv)
    {-# INLINE gsplitBV #-}
    gjoinBV _ _ (K1 bv) = coerceVar bv
    {-# INLINE gjoinBV #-}
-- | Metadata wrappers ('M1') are transparent. Strip or re-add the wrapper
-- and delegate to the wrapped representation, reusing the same lists.
instance BVGroup s as i o
      => BVGroup s as (M1 p c i) (M1 p c o) where
    gsplitBV afs zfs bv = M1 (gsplitBV afs zfs (coerceVar @_ @(i ()) bv))
    {-# INLINE gsplitBV #-}
    gjoinBV afs zfs (M1 inner) = coerceVar @(i ()) (gjoinBV afs zfs inner)
    {-# INLINE gjoinBV #-}
-- | Empty representations ('V1') have no values, so neither direction can
-- produce a genuine result.
instance BVGroup s '[] V1 V1 where
    -- There is no real @V1 ()@ to return. Forging one with 'unsafeCoerce'
    -- would create a 'BVar' masquerading as an inhabitant of an empty
    -- type. Any consumer that pattern-matches on it would then inspect a
    -- mis-typed value, so return bottom instead. The 'seq' keeps the old
    -- behaviour when the incoming 'BVar' is itself bottom: that failure
    -- still surfaces first. Evaluation remains lazy until the result is
    -- demanded.
    gsplitBV _ _ v =
        v `seq` error "Numeric.Backprop.Explicit.gsplitBV: V1 has no values"
    {-# INLINE gsplitBV #-}
    -- EmptyCase: a value of type @V1 ()@ can never be matched, so there
    -- are no alternatives.
    gjoinBV _ _ = \case
    {-# INLINE gjoinBV #-}
-- | Nullary constructors ('U1') carry no fields. Splitting always yields
-- 'U1', and joining yields the constant 'BVar' @constVar U1@.
instance BVGroup s '[] U1 U1 where
    gsplitBV _ _ = const U1
    {-# INLINE gsplitBV #-}
    gjoinBV _ _ = const (constVar U1)
    {-# INLINE gjoinBV #-}
-- | Products: a 'BVar' of @(i1 :*: i2) ()@ is split into its two halves,
-- and each half is then split recursively.
--
-- The first two 'AddFunc' / 'ZeroFunc' entries belong to the halves
-- themselves (@i1 ()@ and @i2 ()@). The remaining entries are cut with
-- 'splitProd' at the length of @as@ (from @Known Length as@): the front
-- goes to the left sub-structure and the rest (@bs@) to the right.
instance ( Reifies s W
         , BVGroup s as i1 o1
         , BVGroup s bs i2 o2
         , cs ~ (as ++ bs)
         , Known Length as
         ) => BVGroup s (i1 () ': i2 () ': cs) (i1 :*: i2) (o1 :*: o2) where
    gsplitBV (afa :< afb :< afs) (zfa :< zfb :< zfs) xy = x :*: y
      where
        (afas, afbs) = splitProd known afs
        (zfas, zfbs) = splitProd known zfs
        -- view each half of xy with viewVar (focused by p1 / p2),
        -- then split that half recursively
        x = gsplitBV afas zfas . viewVar afa zfa p1 $ xy
        y = gsplitBV afbs zfbs . viewVar afb zfb p2 $ xy
    {-# INLINE gsplitBV #-}
    -- Join each half recursively, then combine the two resulting BVars
    -- with isoVar2, using (:*:) forward and unP backward.
    gjoinBV (afa :< afb :< afs) (zfa :< zfb :< zfs) (x :*: y)
        = isoVar2 afa afb zfab (:*:) unP
            (gjoinBV afas zfas x)
            (gjoinBV afbs zfbs y)
      where
        -- ZeroFunc for the whole product, applied component-wise
        zfab = ZF $ \(xx :*: yy) -> runZF zfa xx :*: runZF zfb yy
        (afas, afbs) = splitProd known afs
        (zfas, zfbs) = splitProd known zfs
        unP (xx :*: yy) = (xx, yy)
    {-# INLINE gjoinBV #-}
-- | Sums: a 'BVar' of @(i1 :+: i2) ()@ is split into whichever side it
-- holds, and that side is then split recursively.
--
-- As for products, the first two 'AddFunc' / 'ZeroFunc' entries are for
-- @i1 ()@ and @i2 ()@. The rest is divided between the two sides with
-- 'splitProd' at the length of @as@.
instance ( Reifies s W
         , BVGroup s as i1 o1
         , BVGroup s bs i2 o2
         , cs ~ (as ++ bs)
         , Known Length as
         ) => BVGroup s (i1 () ': i2 () ': cs) (i1 :+: i2) (o1 :+: o2) where
    -- Try the left side with previewVar through s1, then the right side
    -- through s2. The trailing 'error' is the fallback if neither preview
    -- matches.
    gsplitBV (afa :< afb :< afs) (zfa :< zfb :< zfs) xy =
        case previewVar afa zfa s1 xy of
          Just x -> L1 $ gsplitBV afas zfas x
          Nothing -> case previewVar afb zfb s2 xy of
            Just y -> R1 $ gsplitBV afbs zfbs y
            Nothing -> error "Numeric.Backprop.gsplitBV: Internal error occurred"
      where
        (afas, afbs) = splitProd known afs
        (zfas, zfbs) = splitProd known zfs
    {-# INLINE gsplitBV #-}
    -- Join the side that is present, then inject it with a one-argument
    -- op. Forward, the op wraps the value in L1 / R1. Backward, a gradient
    -- on the same side yields its payload. A gradient on the other side
    -- yields runZF of the present side's value (runZF zfa xx / runZF zfb yy).
    gjoinBV (afa :< afb :< afs) (zfa :< zfb :< zfs) = \case
        L1 x -> liftOp1 afa zf (op1 (\xx -> (L1 xx, \case L1 d -> d; R1 _ -> runZF zfa xx)))
                  (gjoinBV afas zfas x)
        R1 y -> liftOp1 afb zf (op1 (\yy -> (R1 yy, \case L1 _ -> runZF zfb yy; R1 d -> d)))
                  (gjoinBV afbs zfbs y)
      where
        (afas, afbs) = splitProd known afs
        (zfas, zfbs) = splitProd known zfs
        -- ZeroFunc for the whole sum, applied to whichever side is present
        zf = ZF $ \case
            L1 xx -> L1 $ runZF zfa xx
            R1 yy -> R1 $ runZF zfb yy
    {-# INLINE gjoinBV #-}
-- | Split a 'BVar' of a generic value @z f@ into a @z@ whose fields are
-- individual 'BVar's.
--
-- The value is first viewed as its "GHC.Generics" representation, using a
-- lens built from 'from' / 'G.to'. That representation is then taken apart
-- field by field with 'gsplitBV'.
--
-- * @af@ / @zf@: 'AddFunc' / 'ZeroFunc' for the representation
--   @Rep (z f) ()@.
-- * @afs@ / @zfs@: one entry per type in @as@, as required by 'BVGroup'.
splitBV
    :: forall z f s as.
       ( Generic (z f)
       , Generic (z (BVar s))
       , BVGroup s as (Rep (z f)) (Rep (z (BVar s)))
       , Reifies s W
       )
    => AddFunc (Rep (z f) ())
    -> Prod AddFunc as
    -> ZeroFunc (Rep (z f) ())
    -> Prod ZeroFunc as
    -> BVar s (z f)
    -> z (BVar s)
splitBV af afs zf zfs v = G.to (gsplitBV afs zfs repVar)
  where
    -- the whole value, viewed as its generic representation
    repVar = viewVar af zf (lens (from @(z f) @()) (const G.to)) v
{-# INLINE splitBV #-}
-- | Inverse direction of 'splitBV': combine a @z@ of per-field 'BVar's
-- into a single 'BVar' of @z f@.
--
-- The fields are joined into one 'BVar' of the generic representation
-- with 'gjoinBV'. That 'BVar' is then viewed as @z f@ through a lens built
-- from 'G.to' / 'from'.
--
-- * @af@ / @zf@: 'AddFunc' / 'ZeroFunc' for @z f@ itself.
-- * @afs@ / @zfs@: one entry per type in @as@, as required by 'BVGroup'.
joinBV
    :: forall z f s as.
       ( Generic (z f)
       , Generic (z (BVar s))
       , BVGroup s as (Rep (z f)) (Rep (z (BVar s)))
       , Reifies s W
       )
    => AddFunc (z f)
    -> Prod AddFunc as
    -> ZeroFunc (z f)
    -> Prod ZeroFunc as
    -> z (BVar s)
    -> BVar s (z f)
joinBV af afs zf zfs v = viewVar af zf (lens G.to (const from)) repVar
  where
    -- a single BVar of the representation, rebuilt from the field BVars
    repVar = gjoinBV afs zfs (from @(z (BVar s)) @() v)
{-# INLINE joinBV #-}