{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.Backprop.Explicit (
BVar, W, Backprop(..), ABP(..), NumBP(..)
, ZeroFunc(..), zfNum, zfNums, zeroFunc, zeroFuncs, zfFunctor
, AddFunc(..), afNum, afNums, addFunc, addFuncs
, OneFunc(..), ofNum, ofNums, oneFunc, oneFuncs, ofFunctor
, backprop, evalBP, gradBP, backpropWith
, evalBP0
, backprop2, evalBP2, gradBP2, backpropWith2
, backpropN, evalBPN, gradBPN, backpropWithN, RPureConstrained
, constVar, auto, coerceVar
, viewVar, setVar, overVar
, sequenceVar, collectVar
, previewVar, toListOfVar
, isoVar, isoVar2, isoVar3, isoVarN
, liftOp
, liftOp1, liftOp2, liftOp3
, splitBV
, joinBV
, BVGroup
, Op(..)
, op0, opConst, idOp
, bpOp
, op1, op2, op3
, opCoerce, opTup, opIso, opIsoN, opLens
, noGrad1, noGrad
, Reifies
) where
import Data.Bifunctor
import Data.Functor.Identity
import Data.Reflection
import Data.Type.Util
import Data.Vinyl.Core
import Data.Vinyl.TypeLevel
import GHC.Generics as G
import Lens.Micro
import Numeric.Backprop.Class
import Numeric.Backprop.Internal
import Numeric.Backprop.Op
import Unsafe.Coerce
-- | A 'Rec' holding one 'ZeroFunc' per type in @as@, each of them 'zfNum'
-- (so every type in @as@ must have a 'Num' instance).
zfNums :: RPureConstrained Num as => Rec ZeroFunc as
zfNums = rpureConstrained @Num zfNum
{-# INLINE zfNums #-}
-- | A 'ZeroFunc' for any 'Functor' container whose elements have a
-- 'Backprop' instance, implemented with 'zeroFunctor'.
zfFunctor :: (Backprop a, Functor f) => ZeroFunc (f a)
zfFunctor = ZF $ \xs -> zeroFunctor xs
{-# INLINE zfFunctor #-}
-- | A 'Rec' holding one 'AddFunc' per type in @as@, each of them 'afNum'
-- (so every type in @as@ must have a 'Num' instance).
afNums :: RPureConstrained Num as => Rec AddFunc as
afNums = rpureConstrained @Num afNum
{-# INLINE afNums #-}
-- | A 'Rec' holding one 'OneFunc' per type in @as@, each of them 'ofNum'
-- (so every type in @as@ must have a 'Num' instance).
ofNums :: RPureConstrained Num as => Rec OneFunc as
ofNums = rpureConstrained @Num ofNum
{-# INLINE ofNums #-}
-- | A 'OneFunc' for any 'Functor' container whose elements have a
-- 'Backprop' instance, implemented with 'oneFunctor'.
ofFunctor :: (Backprop a, Functor f) => OneFunc (f a)
ofFunctor = OF $ \xs -> oneFunctor xs
{-# INLINE ofFunctor #-}
-- | A 'Rec' holding one 'ZeroFunc' per type in @as@, each taken from that
-- type's 'Backprop' instance via 'zeroFunc'.
zeroFuncs :: RPureConstrained Backprop as => Rec ZeroFunc as
zeroFuncs = rpureConstrained @Backprop zeroFunc
{-# INLINE zeroFuncs #-}
-- | A 'Rec' holding one 'AddFunc' per type in @as@, each taken from that
-- type's 'Backprop' instance via 'addFunc'.
addFuncs :: RPureConstrained Backprop as => Rec AddFunc as
addFuncs = rpureConstrained @Backprop addFunc
{-# INLINE addFuncs #-}
-- | A 'Rec' holding one 'OneFunc' per type in @as@, each taken from that
-- type's 'Backprop' instance via 'oneFunc'.
oneFuncs :: RPureConstrained Backprop as => Rec OneFunc as
oneFuncs = rpureConstrained @Backprop oneFunc
{-# INLINE oneFuncs #-}
-- | Synonym for 'constVar': wraps a plain value as a 'BVar'.
auto :: a -> BVar s a
auto x = constVar x
{-# INLINE auto #-}
-- | Run a function of a whole 'Rec' of inputs through 'backpropWithN', then
-- seed the returned continuation with @'runOF' ob@ applied to the result.
--
-- Returns the function's result paired with whatever that continuation
-- produces, i.e. the gradient for every input.
backpropN
    :: forall as b. ()
    => Rec ZeroFunc as
    -> OneFunc b
    -> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> (b, Rec Identity as)
backpropN zfs ob f xs = seedWith (backpropWithN zfs f xs)
  where
    -- Matching the pair in an argument position is strict, exactly like a
    -- top-level @case@ on the result of 'backpropWithN'.
    seedWith (y, g) = (y, g (runOF ob y))
{-# INLINE backpropN #-}
-- | Single-input form of 'backpropN'.
--
-- Wraps the input in a one-element 'Rec', runs 'backpropN', and unwraps
-- the one-element gradient 'Rec' again.
backprop
    :: ZeroFunc a
    -> OneFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b, a)
backprop zfa ofb f x =
    second (\(Identity dx :& RNil) -> dx) $
      backpropN (zfa :& RNil) ofb
        -- The projection out of the 'Rec' happens lazily, inside the
        -- argument passed to @f@.
        (\vs -> f (case vs of v :& RNil -> v))
        (Identity x :& RNil)
{-# INLINE backprop #-}
-- | Single-input form of 'backpropWithN'.
--
-- Returns the result together with a function that takes a value of the
-- result type and yields the (unwrapped) gradient for the single input.
backpropWith
    :: ZeroFunc a
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b, b -> a)
backpropWith zfa f x =
    second (\k d -> case k d of Identity dx :& RNil -> dx) $
      backpropWithN (zfa :& RNil)
        (\vs -> f (case vs of v :& RNil -> v))
        (Identity x :& RNil)
{-# INLINE backpropWith #-}
-- | Evaluate a 'BVar' computation that takes no inputs, by handing it to
-- 'evalBPN' with an empty 'Rec'.
evalBP0 :: (forall s. Reifies s W => BVar s a) -> a
evalBP0 x = evalBPN (\_ -> x) RNil
{-# INLINE evalBP0 #-}
-- | Evaluate a single-input 'BVar' function and return its result, using
-- 'evalBPN' on a one-element 'Rec'.
evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> b
evalBP f x = evalBPN (\vs -> f (case vs of v :& RNil -> v)) (Identity x :& RNil)
{-# INLINE evalBP #-}
-- | The gradient part (second component) of 'backprop'.
gradBP
    :: ZeroFunc a
    -> OneFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> a
gradBP zfa ofb f x = snd $ backprop zfa ofb f x
{-# INLINE gradBP #-}
-- | The gradient part (second component) of 'backpropN'.
gradBPN
    :: Rec ZeroFunc as
    -> OneFunc b
    -> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> Rec Identity as
gradBPN zfas ofb f xs = snd $ backpropN zfas ofb f xs
{-# INLINE gradBPN #-}
-- | Two-input form of 'backpropN'.
--
-- Returns the function's result and the pair of gradients for the two
-- inputs.
backprop2
    :: ZeroFunc a
    -> ZeroFunc b
    -> OneFunc c
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c, (a, b))
backprop2 zfa zfb ofc f x y =
    second (\case Identity dx :& Identity dy :& RNil -> (dx, dy)) $
      backpropN zfs ofc (\case x' :& y' :& RNil -> f x' y') xs
  where
    zfs = zfa :& zfb :& RNil
    xs  = Identity x :& Identity y :& RNil
{-# INLINE backprop2 #-}
-- | Two-input form of 'backpropWithN'.
--
-- Returns the result together with a function from a value of the result
-- type to the pair of gradients for the two inputs.
backpropWith2
    :: ZeroFunc a
    -> ZeroFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c, c -> (a, b))
backpropWith2 zfa zfb f x y =
    second (\k d -> case k d of Identity dx :& Identity dy :& RNil -> (dx, dy)) $
      backpropWithN zfs (\case x' :& y' :& RNil -> f x' y') xs
  where
    zfs = zfa :& zfb :& RNil
    xs  = Identity x :& Identity y :& RNil
{-# INLINE backpropWith2 #-}
-- | Evaluate a two-input 'BVar' function and return its result, using
-- 'evalBPN' on a two-element 'Rec'.
evalBP2
    :: (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> c
evalBP2 f x y = evalBPN (\case x' :& y' :& RNil -> f x' y') inputs
  where
    inputs = Identity x :& Identity y :& RNil
{-# INLINE evalBP2 #-}
-- | The gradient part (second component) of 'backprop2'.
gradBP2
    :: ZeroFunc a
    -> ZeroFunc b
    -> OneFunc c
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (a, b)
gradBP2 zfa zfb ofc f x y = snd $ backprop2 zfa zfb ofc f x y
{-# INLINE gradBP2 #-}
-- | Package a 'BVar' function as an 'Op' whose implementation is
-- 'backpropWithN' applied to that function.
bpOp
    :: Rec ZeroFunc as
    -> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Op as b
bpOp zfs f = Op $ \xs -> backpropWithN zfs f xs
{-# INLINE bpOp #-}
-- | Modify the part of a 'BVar' focused by a lens.
--
-- Views the focus with 'viewVar', applies the given 'BVar' function to it,
-- and writes the outcome back into the original 'BVar' with 'setVar'.
overVar
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> (BVar s a -> BVar s a)
    -> BVar s b
    -> BVar s b
overVar afa afb zfa zfb l f x = setVar afa afb zfa l updated x
  where
    updated = f (viewVar afa zfb l x)
{-# INLINE overVar #-}
-- | Apply a pair of functions (which the caller should ensure form an
-- isomorphism) to a 'BVar', via 'opIso' and 'liftOp1'.
isoVar
    :: Reifies s W
    => AddFunc a
    -> (a -> b)
    -> (b -> a)
    -> BVar s a
    -> BVar s b
isoVar af fwd bwd v = liftOp1 af (opIso fwd bwd) v
{-# INLINE isoVar #-}
-- | Two-argument version of 'isoVar': combine two 'BVar's with a function
-- whose inverse (splitting back into a pair) is also supplied, via
-- 'opIso2' and 'liftOp2'.
isoVar2
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> (a -> b -> c)
    -> (c -> (a, b))
    -> BVar s a
    -> BVar s b
    -> BVar s c
isoVar2 afa afb fwd bwd u v = liftOp2 afa afb (opIso2 fwd bwd) u v
{-# INLINE isoVar2 #-}
-- | Three-argument version of 'isoVar', via 'opIso3' and 'liftOp3'.
isoVar3
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> (a -> b -> c -> d)
    -> (d -> (a, b, c))
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
isoVar3 afa afb afc fwd bwd u v w = liftOp3 afa afb afc (opIso3 fwd bwd) u v w
{-# INLINE isoVar3 #-}
-- | 'Rec'-of-inputs version of 'isoVar', via 'opIsoN' and 'liftOp'.
isoVarN
    :: Reifies s W
    => Rec AddFunc as
    -> (Rec Identity as -> b)
    -> (b -> Rec Identity as)
    -> Rec (BVar s) as
    -> BVar s b
isoVarN afs fwd bwd vs = liftOp afs (opIsoN fwd bwd) vs
{-# INLINE isoVarN #-}
-- | Helper class used by 'splitBV' and 'joinBV' to walk a generic
-- representation.
--
-- * @s@ is the 'BVar' type parameter.
-- * @as@ is the type-level list of types for which an 'AddFunc' and a
--   'ZeroFunc' must be supplied (see the instances for how it is built up).
-- * @i@ is the generic representation of the plain value.
-- * @o@ is the generic representation of the corresponding value whose
--   fields are 'BVar's.
class BVGroup s as i o | o -> i, i -> as where
    -- | Split a 'BVar' of representation @i@ into a representation @o@
    -- containing 'BVar's.
    gsplitBV :: Rec AddFunc as -> Rec ZeroFunc as -> BVar s (i ()) -> o ()
    -- | The reverse direction: combine a representation @o@ containing
    -- 'BVar's into a single 'BVar' of representation @i@.
    gjoinBV :: Rec AddFunc as -> Rec ZeroFunc as -> o () -> BVar s (i ())
-- | Leaf fields: a @K1 i a@ field corresponds to a @BVar s a@ field. Both
-- directions are just 'coerceVar' across the 'K1' newtype, so no extra
-- 'AddFunc' / 'ZeroFunc' is needed (@as@ is empty).
instance BVGroup s '[] (K1 i a) (K1 i (BVar s a)) where
    gsplitBV _ _ v = K1 (coerceVar v)
    {-# INLINE gsplitBV #-}
    gjoinBV _ _ (K1 v) = coerceVar v
    {-# INLINE gjoinBV #-}
-- | Metadata wrappers: strip or restore the 'M1' layer with 'coerceVar' and
-- delegate to the wrapped representation's instance.
instance BVGroup s as i o
      => BVGroup s as (M1 p c i) (M1 p c o) where
    gsplitBV afs zfs v = M1 (gsplitBV afs zfs (coerceVar @_ @(i ()) v))
    {-# INLINE gsplitBV #-}
    gjoinBV afs zfs (M1 inner) = coerceVar @(i ()) (gjoinBV afs zfs inner)
    {-# INLINE gjoinBV #-}
-- | Empty representations. @V1 ()@ has no constructors (the empty @case@ in
-- 'gjoinBV' needs no alternatives), so there is no real value to build.
instance BVGroup s '[] V1 V1 where
    -- The 'BVar' is reinterpreted as the uninhabited @V1 ()@ with
    -- 'unsafeCoerce'. NOTE(review): this is only sound because no
    -- well-defined @V1 ()@ value can exist; confirm that callers never
    -- inspect the result expecting a constructor.
    gsplitBV _ _ v = unsafeCoerce v
    {-# INLINE gsplitBV #-}
    gjoinBV _ _ v = case v of {}
    {-# INLINE gjoinBV #-}
-- | Nullary constructors. Splitting ignores the 'BVar' and yields 'U1';
-- joining ignores its input and yields @'constVar' 'U1'@.
instance BVGroup s '[] U1 U1 where
    gsplitBV _ _ = const U1
    {-# INLINE gsplitBV #-}
    gjoinBV _ _ = const (constVar U1)
    {-# INLINE gjoinBV #-}
-- | Products. The first two entries of each 'Rec' are the 'AddFunc' /
-- 'ZeroFunc' for the whole left and right halves (@i1 ()@ and @i2 ()@).
-- The remaining @as ++ bs@ entries are separated with 'splitRec' and
-- handed to the two sub-instances.
instance ( Reifies s W
         , BVGroup s as i1 o1
         , BVGroup s bs i2 o2
         , cs ~ (as ++ bs)
         , RecApplicative as
         ) => BVGroup s (i1 () ': i2 () ': cs) (i1 :*: i2) (o1 :*: o2) where
    gsplitBV (afa :& afb :& afs) (zfa :& zfb :& zfs) xy = x :*: y
      where
        (afas, afbs) = splitRec afs
        (zfas, zfbs) = splitRec zfs
        -- Zero for the whole product, built component-wise from the two
        -- halves' 'ZeroFunc's.
        zfab = ZF $ \(xx :*: yy) -> runZF zfa xx :*: runZF zfb yy
        -- View each half of the product 'BVar' through p1 / p2 (presumably
        -- lenses onto the left / right component of ':*:' from
        -- Data.Type.Util; confirm), then recurse into that half.
        x = gsplitBV afas zfas . viewVar afa zfab p1 $ xy
        y = gsplitBV afbs zfbs . viewVar afb zfab p2 $ xy
    {-# INLINE gsplitBV #-}
    -- Joining only needs the halves' 'AddFunc's. The two joined halves are
    -- recombined with 'isoVar2' using (':*:') and its inverse @unP@.
    gjoinBV (afa :& afb :& afs) (_ :& _ :& zfs) (x :*: y)
        = isoVar2 afa afb (:*:) unP
                  (gjoinBV afas zfas x)
                  (gjoinBV afbs zfbs y)
      where
        (afas, afbs) = splitRec afs
        (zfas, zfbs) = splitRec zfs
        unP (xx :*: yy) = (xx, yy)
    {-# INLINE gjoinBV #-}
-- | Sums. As with products, the first two entries of each 'Rec' cover the
-- 'L1' / 'R1' payloads (@i1 ()@ and @i2 ()@). The rest is separated with
-- 'splitRec' and handed to the two sub-instances.
instance ( Reifies s W
         , BVGroup s as i1 o1
         , BVGroup s bs i2 o2
         , cs ~ (as ++ bs)
         , RecApplicative as
         ) => BVGroup s (i1 () ': i2 () ': cs) (i1 :+: i2) (o1 :+: o2) where
    -- Try the left branch first, then the right. s1 / s2 are presumably
    -- traversals focusing on 'L1' / 'R1' (from Data.Type.Util; confirm).
    gsplitBV (afa :& afb :& afs) (zfa :& zfb :& zfs) xy =
        case previewVar afa zf s1 xy of
          Just x -> L1 $ gsplitBV afas zfas x
          Nothing -> case previewVar afb zf s2 xy of
            Just y -> R1 $ gsplitBV afbs zfbs y
            -- Expected to be unreachable as long as s1 / s2 between them
            -- match every value of the sum.
            Nothing -> error "Numeric.Backprop.gsplitBV: Internal error occurred"
      where
        -- Zero for the whole sum: zero out whichever branch is present.
        zf = ZF $ \case
            L1 xx -> L1 $ runZF zfa xx
            R1 yy -> R1 $ runZF zfb yy
        (afas, afbs) = splitRec afs
        (zfas, zfbs) = splitRec zfs
    {-# INLINE gsplitBV #-}
    -- Inject the joined branch with 'L1' / 'R1' through 'op1'. For the
    -- second component given to 'op1', a value on the matching constructor
    -- is passed through unchanged. A value on the other constructor yields
    -- 'runZF' of the original input instead.
    gjoinBV (afa :& afb :& afs) (zfa :& zfb :& zfs) = \case
        L1 x -> liftOp1 afa (op1 (\xx -> (L1 xx, \case L1 d -> d; R1 _ -> runZF zfa xx)))
                  (gjoinBV afas zfas x)
        R1 y -> liftOp1 afb (op1 (\yy -> (R1 yy, \case L1 _ -> runZF zfb yy; R1 d -> d)))
                  (gjoinBV afbs zfbs y)
      where
        (afas, afbs) = splitRec afs
        (zfas, zfbs) = splitRec zfs
    {-# INLINE gjoinBV #-}
-- | Split a 'BVar' of a 'Generic' value @z f@ into the same structure with
-- every field replaced by a 'BVar' (@z (BVar s)@).
--
-- The 'BVar' is first viewed as its generic representation through a lens
-- built from 'from' / 'G.to'. That representation is split field-by-field
-- with 'gsplitBV', and the result is rebuilt with 'G.to'.
splitBV
    :: forall z f s as.
       ( Generic (z f)
       , Generic (z (BVar s))
       , BVGroup s as (Rep (z f)) (Rep (z (BVar s)))
       , Reifies s W
       )
    => AddFunc (Rep (z f) ())
    -> Rec AddFunc as
    -> ZeroFunc (z f)
    -> Rec ZeroFunc as
    -> BVar s (z f)
    -> z (BVar s)
splitBV af afs zf zfs v = G.to (gsplitBV afs zfs viewed)
  where
    -- The lens is written inline so it keeps its polymorphic 'Lens'' type.
    viewed = viewVar af zf (lens (from @(z f) @()) (const G.to)) v
{-# INLINE splitBV #-}
-- | Inverse of 'splitBV': combine a value whose fields are 'BVar's into a
-- single 'BVar' of the plain 'Generic' value.
--
-- The value is converted to its generic representation with 'from', joined
-- into one 'BVar' with 'gjoinBV', and then viewed as @z f@ through a lens
-- built from 'G.to' / 'from'.
joinBV
    :: forall z f s as.
       ( Generic (z f)
       , Generic (z (BVar s))
       , BVGroup s as (Rep (z f)) (Rep (z (BVar s)))
       , Reifies s W
       )
    => AddFunc (z f)
    -> Rec AddFunc as
    -> ZeroFunc (Rep (z f) ())
    -> Rec ZeroFunc as
    -> z (BVar s)
    -> BVar s (z f)
joinBV af afs zf zfs v = viewVar af zf (lens G.to (const from)) grouped
  where
    grouped = gjoinBV afs zfs (from @(z (BVar s)) @() v)
{-# INLINE joinBV #-}