{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
module Numeric.Backprop.Explicit (
BVar, W, Backprop(..), ABP(..), NumBP(..)
, ZeroFunc(..), zfNum, zfNums, zeroFunc, zeroFuncs, zfFunctor
, AddFunc(..), afNum, afNums, addFunc, addFuncs
, OneFunc(..), ofNum, ofNums, oneFunc, oneFuncs, ofFunctor
, backprop, evalBP, gradBP, backpropWith
, backprop2, evalBP2, gradBP2, backpropWith2
, backpropN, evalBPN, gradBPN, backpropWithN, Every
, constVar, auto, coerceVar
, viewVar, setVar
, sequenceVar, collectVar
, previewVar, toListOfVar
, isoVar, isoVar2, isoVar3, isoVarN
, liftOp
, liftOp1, liftOp2, liftOp3
, Op(..)
, op0, opConst, idOp
, opConst'
, op1, op2, op3
, opCoerce, opTup, opIso, opIsoN, opLens
, noGrad1, noGrad
, Prod(..), pattern (:>), only, head'
, Tuple, pattern (::<), only_
, I(..)
, Reifies
) where
import Data.Bifunctor
import Data.Reflection
import Data.Type.Index
import Data.Type.Length
import Data.Type.Product
import Numeric.Backprop.Class
import Numeric.Backprop.Internal
import Numeric.Backprop.Op
import Type.Class.Higher
import Type.Class.Known
import Type.Class.Witness
-- | A 'ZeroFunc' for every type in the type-level list @as@, each one
-- built with 'zfNum' from that type's 'Num' instance.
--
-- The per-position 'Num' instance is recovered from the @'Every' 'Num' as@
-- constraint with 'every'.  'indices' supplies one index per position,
-- which is presumably what requires @'Known' 'Length' as@.
zfNums :: (Every Num as, Known Length as) => Prod ZeroFunc as
zfNums = map1 (\i -> zfNum \\ every @_ @Num i) indices
{-# INLINE zfNums #-}
-- | A 'ZeroFunc' for a 'Functor' of 'Backprop' values, built from
-- 'zeroFunctor'.
zfFunctor :: (Backprop a, Functor f) => ZeroFunc (f a)
zfFunctor = ZF (\fa -> zeroFunctor fa)
{-# INLINE zfFunctor #-}
-- | An 'AddFunc' for every type in the type-level list @as@, each one
-- built with 'afNum' from that type's 'Num' instance.
--
-- The per-position 'Num' instance is recovered from the @'Every' 'Num' as@
-- constraint with 'every'.  'indices' supplies one index per position,
-- which is presumably what requires @'Known' 'Length' as@.
afNums :: (Every Num as, Known Length as) => Prod AddFunc as
afNums = map1 (\i -> afNum \\ every @_ @Num i) indices
{-# INLINE afNums #-}
-- | A 'OneFunc' for every type in the type-level list @as@, each one
-- built with 'ofNum' from that type's 'Num' instance.
--
-- The per-position 'Num' instance is recovered from the @'Every' 'Num' as@
-- constraint with 'every'.  'indices' supplies one index per position,
-- which is presumably what requires @'Known' 'Length' as@.
ofNums :: (Every Num as, Known Length as) => Prod OneFunc as
ofNums = map1 (\i -> ofNum \\ every @_ @Num i) indices
{-# INLINE ofNums #-}
-- | A 'OneFunc' for a 'Functor' of 'Backprop' values, built from
-- 'oneFunctor'.
ofFunctor :: (Backprop a, Functor f) => OneFunc (f a)
ofFunctor = OF (\fa -> oneFunctor fa)
{-# INLINE ofFunctor #-}
-- | The 'ZeroFunc' given by the 'zero' method of the type's 'Backprop'
-- instance.
zeroFunc :: Backprop a => ZeroFunc a
zeroFunc = ZF (\x -> zero x)
{-# INLINE zeroFunc #-}
-- | The 'AddFunc' given by the 'add' method of the type's 'Backprop'
-- instance.
addFunc :: Backprop a => AddFunc a
addFunc = AF (\x y -> add x y)
{-# INLINE addFunc #-}
-- | The 'OneFunc' given by the 'one' method of the type's 'Backprop'
-- instance.
oneFunc :: Backprop a => OneFunc a
oneFunc = OF (\x -> one x)
{-# INLINE oneFunc #-}
-- | A 'ZeroFunc' for every type in the type-level list @as@, each one
-- built with 'zeroFunc' from that type's 'Backprop' instance.
--
-- The per-position 'Backprop' instance is recovered from the
-- @'Every' 'Backprop' as@ constraint with 'every'.  'indices' supplies one
-- index per position, which is presumably what requires
-- @'Known' 'Length' as@.
zeroFuncs :: (Every Backprop as, Known Length as) => Prod ZeroFunc as
zeroFuncs = map1 (\i -> zeroFunc \\ every @_ @Backprop i) indices
{-# INLINE zeroFuncs #-}
-- | An 'AddFunc' for every type in the type-level list @as@, each one
-- built with 'addFunc' from that type's 'Backprop' instance.
--
-- The per-position 'Backprop' instance is recovered from the
-- @'Every' 'Backprop' as@ constraint with 'every'.  'indices' supplies one
-- index per position, which is presumably what requires
-- @'Known' 'Length' as@.
addFuncs :: (Every Backprop as, Known Length as) => Prod AddFunc as
addFuncs = map1 (\i -> addFunc \\ every @_ @Backprop i) indices
{-# INLINE addFuncs #-}
-- | A 'OneFunc' for every type in the type-level list @as@, each one
-- built with 'oneFunc' from that type's 'Backprop' instance.
--
-- The per-position 'Backprop' instance is recovered from the
-- @'Every' 'Backprop' as@ constraint with 'every'.  'indices' supplies one
-- index per position, which is presumably what requires
-- @'Known' 'Length' as@.
oneFuncs :: (Every Backprop as, Known Length as) => Prod OneFunc as
oneFuncs = map1 (\i -> oneFunc \\ every @_ @Backprop i) indices
{-# INLINE oneFuncs #-}
-- | Synonym for 'constVar': wraps a plain value as a 'BVar'.
auto :: a -> BVar s a
auto x = constVar x
{-# INLINE auto #-}
-- | Variant of 'backpropN' that takes the output's 'OneFunc' as a plain
-- @b -> b@ function, supplied as the last argument.
--
-- NOTE(review): the function is handed to 'backpropN' wrapped in 'OF';
-- presumably it produces the seed gradient from the result — confirm
-- against 'backpropN'.
backpropWithN
    :: Prod ZeroFunc as
    -> (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> (b -> b)
    -> (b, Tuple as)
backpropWithN zfs f xs g = backpropN zfs seed f xs
  where
    seed = OF g
{-# INLINE backpropWithN #-}
-- | Single-input version of 'backpropN'.
--
-- The input is wrapped into a one-element 'Tuple' with 'only_'.  The
-- one-element gradient 'Tuple' that 'backpropN' returns is unwrapped again
-- with 'head'' and 'getI'.  The first component of the result is the
-- output (type @b@).  The second is the input-typed gradient.
backprop
    :: ZeroFunc a
    -> OneFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b, a)
backprop zfa ofb f = \x ->
    -- lazy pattern binding: neither component is forced until demanded
    let (y, grads) = backpropN (zfa :< Ø) ofb (\bxs -> f (head' bxs)) (only_ x)
    in  (y, getI (head' grads))
{-# INLINE backprop #-}
-- | Variant of 'backprop' that takes the output's 'OneFunc' as a plain
-- @b -> b@ function, supplied as the last argument.
backpropWith
    :: ZeroFunc a
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> (b -> b)
    -> (b, a)
backpropWith zfa f x g = backprop zfa seed f x
  where
    seed = OF g
{-# INLINE backpropWith #-}
-- | Single-input version of 'evalBPN'.  It returns only the output, with no
-- gradient.
evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> b
evalBP f = \x -> evalBPN (\bxs -> f (head' bxs)) (only_ x)
{-# INLINE evalBP #-}
-- | The gradient (second) component of 'backprop', discarding the output.
gradBP
    :: ZeroFunc a
    -> OneFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b)
    -> a
    -> a
gradBP zfa ofb f = \x -> case backprop zfa ofb f x of
    (_, dx) -> dx
{-# INLINE gradBP #-}
-- | The gradient (second) component of 'backpropN', discarding the output.
gradBPN
    :: Prod ZeroFunc as
    -> OneFunc b
    -> (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> Tuple as
gradBPN zfas ofb f = \xs -> case backpropN zfas ofb f xs of
    (_, dxs) -> dxs
{-# INLINE gradBPN #-}
-- | Two-input version of 'backpropN'.
--
-- The two inputs are packed into a two-element 'Tuple'.  The two-element
-- gradient 'Tuple' is unpacked into an ordinary pair.
backprop2
    :: ZeroFunc a
    -> ZeroFunc b
    -> OneFunc c
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c, (a, b))
backprop2 zfa zfb ofc f x y =
    second (\(gx ::< gy ::< Ø) -> (gx, gy)) result
  where
    inputs = x ::< y ::< Ø
    result = backpropN (zfa :< zfb :< Ø) ofc
                 (\(vx :< vy :< Ø) -> f vx vy)
                 inputs
{-# INLINE backprop2 #-}
-- | Variant of 'backprop2' that takes the output's 'OneFunc' as a plain
-- @c -> c@ function, supplied as the last argument.
backpropWith2
    :: ZeroFunc a
    -> ZeroFunc b
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (c -> c)
    -> (c, (a, b))
backpropWith2 zfa zfb f x y g = backprop2 zfa zfb seed f x y
  where
    seed = OF g
{-# INLINE backpropWith2 #-}
-- | Two-input version of 'evalBPN'.  It returns only the output, with no
-- gradients.
evalBP2
    :: (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> c
evalBP2 f x y = evalBPN (\(vx :< vy :< Ø) -> f vx vy) inputs
  where
    inputs = x ::< y ::< Ø
{-# INLINE evalBP2 #-}
-- | The pair of gradients from 'backprop2', discarding the output.
gradBP2
    :: ZeroFunc a
    -> ZeroFunc b
    -> OneFunc c
    -> (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c)
    -> a
    -> b
    -> (a, b)
gradBP2 zfa zfb ofc f x = \y -> case backprop2 zfa zfb ofc f x y of
    (_, dxy) -> dxy
{-# INLINE gradBP2 #-}
-- | Map a 'BVar' through a forward function @f :: a -> b@ and a backward
-- function @g :: b -> a@, using 'opIso' lifted with 'liftOp1'.
--
-- NOTE(review): @f@ and @g@ are presumably expected to be mutual inverses;
-- nothing here checks that.
isoVar
    :: Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> (a -> b)
    -> (b -> a)
    -> BVar s a
    -> BVar s b
isoVar af z f g = liftOp1 af z isoOp
  where
    isoOp = opIso f g
{-# INLINE isoVar #-}
-- | Combine two 'BVar's through a forward function @f :: a -> b -> c@ and
-- a backward splitting function @g :: c -> (a, b)@, using 'opIso2' lifted
-- with 'liftOp2'.
--
-- NOTE(review): @f@ and @g@ are presumably expected to be mutual inverses;
-- nothing here checks that.
isoVar2
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc c
    -> (a -> b -> c)
    -> (c -> (a, b))
    -> BVar s a
    -> BVar s b
    -> BVar s c
isoVar2 afa afb z f g = liftOp2 afa afb z isoOp
  where
    isoOp = opIso2 f g
{-# INLINE isoVar2 #-}
-- | Combine three 'BVar's through a forward function
-- @f :: a -> b -> c -> d@ and a backward splitting function
-- @g :: d -> (a, b, c)@, using 'opIso3' lifted with 'liftOp3'.
--
-- NOTE(review): @f@ and @g@ are presumably expected to be mutual inverses;
-- nothing here checks that.
isoVar3
    :: Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> ZeroFunc d
    -> (a -> b -> c -> d)
    -> (d -> (a, b, c))
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
isoVar3 afa afb afc z f g = liftOp3 afa afb afc z isoOp
  where
    isoOp = opIso3 f g
{-# INLINE isoVar3 #-}
-- | Combine a 'Prod' of 'BVar's through a forward function
-- @f :: Tuple as -> b@ and a backward function @g :: b -> Tuple as@, using
-- 'opIsoN' lifted with 'liftOp'.
--
-- NOTE(review): @f@ and @g@ are presumably expected to be mutual inverses;
-- nothing here checks that.
isoVarN
    :: Reifies s W
    => Prod AddFunc as
    -> ZeroFunc b
    -> (Tuple as -> b)
    -> (b -> Tuple as)
    -> Prod (BVar s) as
    -> BVar s b
isoVarN afs z f g = liftOp afs z isoOp
  where
    isoOp = opIsoN f g
{-# INLINE isoVarN #-}