{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.Backprop.Internal (
BVar
, W
, backpropWithN, evalBPN
, constVar
, liftOp, liftOp1, liftOp2, liftOp3
, viewVar, setVar, sequenceVar, collectVar, previewVar, toListOfVar
, coerceVar
, ZeroFunc(..), zfNum, zeroFunc
, AddFunc(..), afNum, addFunc
, OneFunc(..), ofNum, oneFunc
, debugSTN
, debugIR
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Coerce
import Data.Foldable
import Data.Function
import Data.Functor.Identity
import Data.IORef
import Data.Kind
import Data.Maybe
import Data.Monoid hiding (Any(..))
import Data.Proxy
import Data.Reflection
import Data.Type.Util
import Data.Typeable
import Data.Vinyl.Core
import GHC.Exts (Any)
import GHC.Generics as G
import Lens.Micro
import Lens.Micro.Extras
import Numeric.Backprop.Class
import Numeric.Backprop.Op
import System.IO.Unsafe
import Unsafe.Coerce
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import qualified Data.Vinyl.Recursive as VR
import qualified Data.Vinyl.XRec as X
-- | Newtype wrapper around a function @a -> a@, used by the functions in this
-- module wherever a \"zero\" shaped like an existing value is required.
-- 'zfNum' and 'zeroFunc' are the ready-made values defined in this module.
newtype ZeroFunc a = ZF { runZF :: a -> a }
-- | Newtype wrapper around a binary function @a -> a -> a@, used by the
-- functions in this module to combine two values of the same type.
-- 'afNum' and 'addFunc' are the ready-made values defined in this module.
newtype AddFunc a = AF { runAF :: a -> a -> a }
-- | Newtype wrapper around a function @a -> a@, the \"one\" counterpart of
-- 'ZeroFunc'.  'ofNum' and 'oneFunc' are the ready-made values defined in
-- this module.
newtype OneFunc a = OF { runOF :: a -> a }
-- | 'ZeroFunc' for any 'Num' instance: the argument is discarded and the
-- literal @0@ is returned.
zfNum :: Num a => ZeroFunc a
zfNum = ZF (\_ -> 0)
{-# INLINE zfNum #-}
-- | 'AddFunc' for any 'Num' instance: combines two values with '+'.
afNum :: Num a => AddFunc a
afNum = AF (\x y -> x + y)
{-# INLINE afNum #-}
-- | 'OneFunc' for any 'Num' instance: the argument is discarded and the
-- literal @1@ is returned.
ofNum :: Num a => OneFunc a
ofNum = OF (\_ -> 1)
{-# INLINE ofNum #-}
-- | A value of type @a@ together with a 'BRef' saying where it came from.
-- The @s@ parameter is phantom here; elsewhere in this module it is the type
-- through which the tape ('W') is reflected.  Both fields are strict.
data BVar s a = BV { _bvRef :: !(BRef s) -- ^ origin of the variable (input, tape node, or constant)
, _bvVal :: !a -- ^ the value carried by the variable
}
-- | Standalone-derived 'Typeable' instance for 'BVar'.
deriving instance Typeable (BVar s a)
-- | 'X.IsoHKD' instance with an empty body, so every member comes from the
-- class's defaults.
instance X.IsoHKD (BVar s) a
-- | Origin of a 'BVar'.  The @s@ parameter is phantom.
data BRef (s :: Type) = BRInp !Int -- ^ an input; 'fillWengert' numbers the inputs from 0
| BRIx !Int -- ^ a tape node; the index is the one handed out by 'insertNode'
| BRC -- ^ a constant (see 'constVar'); not recorded anywhere
deriving (Generic, Show)
-- | 'NFData' instance with an empty body; 'BRef' derives 'Generic', so the
-- class's generic default is what gets used.
instance NFData (BRef s)
-- | Fully evaluates the reference first and then the carried value.
instance NFData a => NFData (BVar s a) where
    rnf (BV ref val) = rnf ref `seq` rnf val
-- | The carried value if the variable is a constant ('BRC'), otherwise
-- 'Nothing'.
bvConst :: BVar s a -> Maybe a
bvConst (BV ref !val) = case ref of
    BRC -> Just val
    _   -> Nothing
{-# INLINE bvConst #-}
-- | Evaluate a 'BVar': the constructor and its value are brought to weak head
-- normal form and the reference is fully evaluated.  Always returns @()@.
forceBVar :: BVar s a -> ()
forceBVar (BV ref !_) = rnf ref
{-# INLINE forceBVar #-}
-- | One input of a tape node, as seen from a gradient contribution of type
-- @a@.  The input variable's own type @b@ is existentially hidden.  During
-- the backward pass ('gradRunner') a contribution is either embedded with
-- '_irEmbed' (when the target slot is still empty) or merged into the slot's
-- current content with '_irAdd'.
data InpRef :: Type -> Type where
IR :: { _irIx :: !(BVar s b) -- ^ the variable this input refers to
, _irAdd :: !(a -> b -> b) -- ^ merge a contribution into an existing accumulated value
, _irEmbed :: !(a -> b) -- ^ turn the first contribution into an accumulated value
}
-> InpRef a
-- | Evaluate an 'InpRef': its two functions are brought to weak head normal
-- form and the referenced variable is forced with 'forceBVar'.
forceInpRef :: InpRef a -> ()
forceInpRef (IR var !_ !_) = forceBVar var
{-# INLINE forceInpRef #-}
-- | Debugging aid: 'show' the 'BRef' of the variable an 'InpRef' points at.
debugIR :: InpRef a -> String
debugIR (IR var _ _) = show (_bvRef var)
-- | A recorded operation whose result has type @a@.  The list of input types
-- @as@ is existentially hidden.
data TapeNode :: Type -> Type where
TN :: { _tnInputs :: !(Rec InpRef as) -- ^ one 'InpRef' per input of the operation
, _tnGrad :: !(a -> Rec Identity as) -- ^ maps the value accumulated for this node to one contribution per input
}
-> TapeNode a
-- | Evaluate a 'TapeNode': the gradient function is brought to weak head
-- normal form and 'forceInpRef' is folded over the inputs.
forceTapeNode :: TapeNode a -> ()
forceTapeNode (TN inputs !_) = VR.rfoldMap forceInpRef inputs
{-# INLINE forceTapeNode #-}
-- | A 'TapeNode' with its result type hidden, so that nodes of different
-- result types can be stored in one list (the tape held by 'W').
data SomeTapeNode :: Type where
STN :: { _stnNode :: !(TapeNode a) -- ^ the wrapped node
}
-> SomeTapeNode
-- | 'forceTapeNode' for a node whose result type is hidden.
forceSomeTapeNode :: SomeTapeNode -> ()
forceSomeTapeNode stn = case stn of
    STN node -> forceTapeNode node
-- | Debugging aid: 'show' the list of 'debugIR' strings of a node's inputs.
debugSTN :: SomeTapeNode -> String
debugSTN (STN (TN inputs _)) = show (VR.rfoldMap (\ir -> [debugIR ir]) inputs)
-- | The tape: a mutable cell holding the number of nodes recorded so far and
-- the nodes themselves.  'insertNode' conses new nodes onto the front, so the
-- list is ordered newest-first.
newtype W = W { wRef :: IORef (Int, [SomeTapeNode]) }
-- | Allocate a fresh, empty tape: zero nodes recorded, empty node list.
initWengert :: IO W
initWengert = fmap W (newIORef (0, []))
{-# INLINE initWengert #-}
-- | Record a node on the tape and return a 'BVar' that refers to it.
--
-- The node is consed onto the front of the tape and the counter is bumped in
-- a single 'atomicModifyIORef''.  The counter's value /before/ the bump is
-- used as the node's index, so indices start at 0 and the newest node always
-- has index @count - 1@.
insertNode
:: TapeNode a
-> a -- ^ value the returned variable carries
-> W -- ^ tape to record on
-> IO (BVar s a)
insertNode tn !x !w = fmap ((`BV` x) . BRIx) . atomicModifyIORef' (wRef w) $ \(!n,!t) ->
let n' = n + 1
t' = STN tn : t
-- force the node and the new state before returning them; the old
-- counter @n@ is the value handed back as the node's index
in forceTapeNode tn `seq` n' `seq` t' `seq` ((n', t'), n)
{-# INLINE insertNode #-}
-- | Wrap a plain value as a constant 'BVar' (reference 'BRC').  Constants are
-- never recorded on the tape.
constVar :: a -> BVar s a
constVar x = BV BRC x
{-# INLINE constVar #-}
-- | 'IO' worker behind 'liftOp': apply an 'Op' to a record of 'BVar's.
--
-- When 'bvConst' succeeds for every input, the op is simply evaluated on the
-- plain values and the result is a constant; nothing is recorded.  Otherwise
-- a 'TapeNode' is inserted into the tape obtained by reflecting @s@.
liftOp_
:: forall s as b. Reifies s W
=> Rec AddFunc as
-> Op as b
-> Rec (BVar s) as
-> IO (BVar s b)
liftOp_ afs o !vs = case rtraverse (fmap Identity . bvConst) vs of
Just xs -> return $ constVar (evalOp o xs)
Nothing -> insertNode tn y (reflect (Proxy @s))
where
-- @y@ becomes the new variable's value, @g@ the node's '_tnGrad'
(y,g) = runOpWith o (VR.rmap (Identity . _bvVal) vs)
tn = TN { _tnInputs = VR.rzipWith go afs vs
, _tnGrad = g
}
-- one 'InpRef' per input: contributions already have the input's type, so
-- they are merged with the supplied 'AddFunc' and embedded with 'id'
go :: forall a. AddFunc a -> BVar s a -> InpRef a
go af !v = forceBVar v `seq` IR v (runAF af) id
{-# INLINE go #-}
{-# INLINE liftOp_ #-}
-- | Pure interface to 'liftOp_': the record of inputs is forced to weak head
-- normal form and the tape update is performed through 'unsafePerformIO'.
liftOp
    :: forall as b s. Reifies s W
    => Rec AddFunc as
    -> Op as b
    -> Rec (BVar s) as
    -> BVar s b
liftOp afs o !vs = unsafePerformIO (liftOp_ afs o vs)
{-# INLINE liftOp #-}
-- | 'IO' worker behind 'liftOp1'.  A constant input yields a constant result
-- without touching the tape; any other input is forced and a one-input
-- 'TapeNode' is recorded.
liftOp1_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> Op '[a] b
    -> BVar s a
    -> IO (BVar s b)
liftOp1_ af o v = case bvConst v of
    Just x  -> pure (constVar (evalOp o (Identity x :& RNil)))
    Nothing -> forceBVar v `seq` insertNode node out (reflect (Proxy @s))
  where
    -- @out@ is the new variable's value, @grad@ the node's '_tnGrad'
    (out, grad) = runOpWith o (Identity (_bvVal v) :& RNil)
    node = TN { _tnInputs = IR v (runAF af) id :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp1_ #-}
-- | Pure interface to 'liftOp1_': the input is forced and the tape update is
-- performed through 'unsafePerformIO'.
liftOp1
    :: forall a b s. Reifies s W
    => AddFunc a
    -> Op '[a] b
    -> BVar s a
    -> BVar s b
liftOp1 af o !v = unsafePerformIO (liftOp1_ af o v)
{-# INLINE liftOp1 #-}
-- | 'IO' worker behind 'liftOp2'.  When both inputs are constants the result
-- is a constant and the tape is untouched; otherwise both inputs are forced
-- (first, then second) and a two-input 'TapeNode' is recorded.
liftOp2_
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> IO (BVar s c)
liftOp2_ afa afb o v u = case (bvConst v, bvConst u) of
    (Just x, Just y) ->
      pure (constVar (evalOp o (Identity x :& Identity y :& RNil)))
    _ ->
      forceBVar v `seq` forceBVar u `seq` insertNode node out (reflect (Proxy @s))
  where
    -- @out@ is the new variable's value, @grad@ the node's '_tnGrad'
    (out, grad) = runOpWith o (Identity (_bvVal v) :& Identity (_bvVal u) :& RNil)
    node = TN { _tnInputs = IR v (runAF afa) id
                         :& IR u (runAF afb) id
                         :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp2_ #-}
-- | Pure interface to 'liftOp2_': both inputs are forced and the tape update
-- is performed through 'unsafePerformIO'.
liftOp2
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> BVar s c
liftOp2 afa afb o !v !u = unsafePerformIO (liftOp2_ afa afb o v u)
{-# INLINE liftOp2 #-}
-- | 'IO' worker behind 'liftOp3'.  When all three inputs are constants the
-- result is a constant and the tape is untouched; otherwise the inputs are
-- forced in argument order and a three-input 'TapeNode' is recorded.
liftOp3_
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> IO (BVar s d)
liftOp3_ afa afb afc o v u w = case (bvConst v, bvConst u, bvConst w) of
    (Just x, Just y, Just z) ->
      pure (constVar (evalOp o (Identity x :& Identity y :& Identity z :& RNil)))
    _ ->
      forceBVar v
        `seq` forceBVar u
        `seq` forceBVar w
        `seq` insertNode node out (reflect (Proxy @s))
  where
    -- @out@ is the new variable's value, @grad@ the node's '_tnGrad'
    (out, grad) = runOpWith o
        (Identity (_bvVal v) :& Identity (_bvVal u) :& Identity (_bvVal w) :& RNil)
    node = TN { _tnInputs = IR v (runAF afa) id
                         :& IR u (runAF afb) id
                         :& IR w (runAF afc) id
                         :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp3_ #-}
-- | Pure interface to 'liftOp3_': all three inputs are forced and the tape
-- update is performed through 'unsafePerformIO'.
liftOp3
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
liftOp3 afa afb afc o !v !u !w = unsafePerformIO (liftOp3_ afa afb afc o v u w)
{-# INLINE liftOp3 #-}
-- | 'IO' worker behind 'viewVar': record a node whose value is the part of
-- the variable's value selected by the lens.
--
-- A contribution for the part is merged into an accumulated whole by applying
-- the 'AddFunc' under the lens; the first contribution is embedded by setting
-- it into the 'ZeroFunc' image of the original whole.
viewVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s b
    -> IO (BVar s a)
viewVar_ af z l v = forceBVar v `seq` insertNode node part (reflect (Proxy @s))
  where
    whole = _bvVal v
    part  = view l whole
    node  = TN { _tnInputs = IR v mergePart embedPart :& RNil
               , _tnGrad   = \d -> Identity d :& RNil
               }
    mergePart d = over l (runAF af d)
    embedPart d = set l d (runZF z whole)
{-# INLINE viewVar_ #-}
-- | Pure interface to 'viewVar_': the variable is forced and the tape update
-- is performed through 'unsafePerformIO'.
viewVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s b
    -> BVar s a
viewVar af z l !v = unsafePerformIO (viewVar_ af z l v)
{-# INLINE viewVar #-}
-- | 'IO' worker behind 'setVar': record a node whose value is the second
-- variable's value with the lens target replaced by the first variable's
-- value.
--
-- The node's gradient function splits an incoming whole into the part under
-- the lens (sent to the first variable) and the whole with that part replaced
-- by its 'ZeroFunc' image (sent to the second variable).
setVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> IO (BVar s b)
setVar_ afa afb za l w v =
    forceBVar v
      `seq` forceBVar w
      `seq` insertNode node updated (reflect (Proxy @s))
  where
    updated = set l (_bvVal w) (_bvVal v)
    node = TN { _tnInputs = IR w (runAF afa) id
                         :& IR v (runAF afb) id
                         :& RNil
              , _tnGrad   = splitGrad
              }
    splitGrad d = Identity dPart :& Identity dRest :& RNil
      where
        -- run the lens in the pair functor: the first component captures the
        -- targeted part, the second is @d@ with that part zeroed
        (dPart, dRest) = l (\p -> (p, runZF za p)) d
{-# INLINE setVar_ #-}
-- | Pure interface to 'setVar_': both variables are forced and the tape
-- update is performed through 'unsafePerformIO'.
setVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> BVar s b
setVar afa afb za l !w !v = unsafePerformIO (setVar_ afa afb za l w v)
{-# INLINE setVar #-}
-- | Turn a variable holding a traversable container into a container of
-- variables, one per element.  Implemented as 'traverseVar'' with 'id' as the
-- extraction function and 'traverse' as the traversal; the container's zero
-- is the element 'ZeroFunc' mapped over it.
sequenceVar
    :: forall t a s. (Reifies s W, Traversable t)
    => AddFunc a
    -> ZeroFunc a
    -> BVar s (t a)
    -> t (BVar s a)
sequenceVar af z !v = unsafePerformIO (traverseVar' af zeroAll id traverse v)
  where
    zeroAll = ZF (fmap (runZF z))
{-# INLINE sequenceVar #-}
-- | 'IO' worker behind 'collectVar': record a single node whose value is the
-- container of the variables' values and whose inputs are the variables
-- themselves (in 'toList' order).
--
-- NOTE(review): 'withVec', 'vecToRec', 'vmap' and 'zipVecList' come from
-- "Data.Type.Util" and are not visible here.  From the 'fromMaybe' below it
-- looks like 'zipVecList' hands the callback 'Nothing' when the incoming
-- list is shorter than the vector, in which case the 'ZeroFunc' image of the
-- corresponding input value is used -- confirm against that module.
collectVar_
:: forall t a s. (Reifies s W, Foldable t, Functor t)
=> AddFunc a
-> ZeroFunc a
-> t (BVar s a)
-> IO (BVar s (t a))
collectVar_ af z !vs = withVec (toList vs) $ \(vVec :: VecT n (BVar s) a) -> do
let tn :: TapeNode (t a)
tn = TN
{ _tnInputs = vecToRec (vmap (\v -> IR v (runAF af) id) vVec)
, _tnGrad = vecToRec
. zipVecList (\v -> Identity . fromMaybe (runZF z (_bvVal v))) vVec
. toList
}
-- force every input variable before the node is recorded
traverse_ (evaluate . forceBVar) vs
insertNode tn (_bvVal <$> vs) (reflect (Proxy @s))
{-# INLINE collectVar_ #-}
-- | Pure interface to 'collectVar_': the container is forced to weak head
-- normal form and the tape update is performed through 'unsafePerformIO'.
collectVar
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> t (BVar s a)
    -> BVar s (t a)
collectVar af z !vs = unsafePerformIO (collectVar_ af z vs)
{-# INLINE collectVar #-}
-- | Shared worker for 'sequenceVar', 'previewVar' and 'toListOfVar'.
--
-- The extraction function @f@ pulls the targeted elements out of the
-- variable's value into some traversable container; one tape node is recorded
-- per element, in traversal order.  The node for the @i@-th element addresses
-- that element inside the whole through @'ixt' t i@.
traverseVar'
:: forall b a f s. (Reifies s W, Traversable f)
=> AddFunc a
-> ZeroFunc b
-> (b -> f a)
-> Traversal' b a
-> BVar s b
-> IO (f (BVar s a))
traverseVar' af z f t v = forceBVar v
`seq` itraverse go (f x)
where
x = _bvVal v
-- record the node for the element at position @i@ whose value is @y@
go :: Int -> a -> IO (BVar s a)
go i y = insertNode tn y (reflect (Proxy @s))
where
-- merge: apply the 'AddFunc' to the @i@-th target of the accumulated whole;
-- embed: set the @i@-th target of the 'ZeroFunc' image of the original whole
tn = TN { _tnInputs = IR v (over (ixt t i) . runAF af)
(\g -> set (ixt t i) g (runZF z x))
:& RNil
, _tnGrad = (:& RNil) . Identity
}
{-# INLINE go #-}
{-# INLINE traverseVar' #-}
-- | Variable for the first target of a traversal, if there is one.
-- Implemented as 'traverseVar'' with 'preview' as the extraction function.
previewVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> Maybe (BVar s a)
previewVar af z t !v = unsafePerformIO (traverseVar' af z (preview t) t v)
{-# INLINE previewVar #-}
-- | Variables for every target of a traversal, in traversal order.
-- Implemented as 'traverseVar'' with 'toListOf' as the extraction function.
toListOfVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> [BVar s a]
toListOfVar af z t !v = unsafePerformIO (traverseVar' af z (toListOf t) t v)
{-# INLINE toListOfVar #-}
-- | Change the type of a variable's value along a 'Coercible' relationship.
-- The reference is kept as is, so no tape node is recorded.
coerceVar
    :: Coercible a b
    => BVar s a
    -> BVar s b
coerceVar v = case v of
    BV ref val -> forceBVar v `seq` BV ref (coerce val)
-- | Mutable state of one backward pass.  Slots hold 'Nothing' until a first
-- contribution arrives; the contents are type-erased to 'Any' and recovered
-- with 'unsafeCoerce' by the code that knows the slot's real type.
data Runner s = R { _rDelta :: !(MV.MVector s (Maybe Any)) -- ^ one slot per tape node, indexed by node index
, _rInputs :: !(MV.MVector s (Maybe Any)) -- ^ one slot per input, indexed by input number
}
-- | Allocate the state for a backward pass.
--
-- The first argument is the tape (node count and nodes, newest first); every
-- node's delta slot is initialised to 'Nothing'.  The second argument is the
-- number of inputs together with the initial content of each input slot.
initRunner
    :: (Int, [SomeTapeNode])
    -> (Int, [Maybe Any])
    -> ST s (Runner s)
initRunner (n, stns) (nx, xs) = do
    delts <- MV.new n
    -- the tape is newest-first, so the head of the list has index @n - 1@
    zipWithM_ (clearSlot delts) [n - 1, n - 2 ..] stns
    inps <- MV.new nx
    zipWithM_ (MV.write inps) [0 ..] xs
    pure (R delts inps)
  where
    clearSlot slots i (STN (TN{} :: TapeNode c)) =
      MV.write slots i (unsafeCoerce (Nothing @c))
{-# INLINE initRunner #-}
-- | Run the backward pass over a tape.
--
-- The seed @o@ is written into the delta slot of node @n - 1@, i.e. the
-- newest node of the tape that is passed in; the caller is therefore expected
-- to supply a tape whose newest node is the one the seed belongs to.  Nodes
-- are then visited from newest to oldest.  A node whose slot is still
-- 'Nothing' is skipped; otherwise its '_tnGrad' is applied to the slot's
-- content and each resulting contribution is pushed to the matching input.
gradRunner
:: forall b s. ()
=> b
-> Runner s
-> (Int, [SomeTapeNode])
-> ST s ()
gradRunner o R{..} (n,stns) = do
when (n > 0) $
MV.write _rDelta (n - 1) (unsafeCoerce (Just o))
zipWithM_ go [n-1,n-2..] stns
where
go :: Int -> SomeTapeNode -> ST s ()
go i (STN (TN{..} :: TapeNode c)) = do
delt <- MV.read _rDelta i
forM_ delt $ \d -> do
-- the slot is type-erased; it is coerced back to this node's result type
let gs = _tnGrad (unsafeCoerce d)
rzipWithM_ propagate _tnInputs gs
{-# INLINE go #-}
-- Deliver one contribution: into the input slots for 'BRInp', into another
-- node's delta slot for 'BRIx', and nowhere for a constant.
propagate :: forall x. InpRef x -> Identity x -> ST s ()
propagate (IR v (+*) e) (Identity d) = case _bvRef v of
BRInp i -> flip (MV.modify _rInputs) i $
unsafeCoerce . bumpMaybe d (+*) e . unsafeCoerce
BRIx i -> flip (MV.modify _rDelta) i $
unsafeCoerce . bumpMaybe d (+*) e . unsafeCoerce
BRC -> return ()
{-# INLINE propagate #-}
{-# INLINE gradRunner #-}
-- | Fold a contribution into an optional accumulator: an empty accumulator is
-- replaced by the embedded contribution, a filled one is combined with the
-- contribution (contribution on the left).
bumpMaybe
    :: a              -- ^ contribution
    -> (a -> b -> b)  -- ^ combine the contribution with the current content
    -> (a -> b)       -- ^ embed the contribution when there is no content yet
    -> Maybe b
    -> Maybe b
bumpMaybe x combine embed = maybe (Just (embed x)) (Just . combine x)
{-# INLINE bumpMaybe #-}
-- | Return the argument unchanged after evaluating parts of it: a 'Left'
-- payload is brought to weak head normal form; for 'Right' the first
-- component is brought to weak head normal form and the unit obtained by
-- folding 'forceSomeTapeNode' over the tape is forced.
--
-- NOTE(review): the fold's result type is @()@, so how much of the tape this
-- really evaluates depends on the 'Monoid' instance of @()@ -- confirm before
-- relying on it for deep evaluation.
seqEither :: Either a (b, [SomeTapeNode]) -> Either a (b, [SomeTapeNode])
seqEither e = case e of
    Left l          -> l `seq` e
    Right (n, tape) -> n `seq` foldMap forceSomeTapeNode tape `seq` e
{-# INLINE seqEither #-}
-- | Evaluate a function on 'BVar's at the given inputs and return its result
-- together with a function from a seed (of the result's type) to one value
-- per input.
--
-- The forward pass ('fillWengert') is run once, via 'unsafePerformIO', when
-- the returned pair is demanded; the returned function re-uses its outcome.
-- If the result variable is directly one of the inputs ('Left'), 'setInput'
-- answers without a backward pass; otherwise ('Right') the backward pass is
-- run over the recorded tape.
backpropWithN
:: forall as b. ()
=> Rec ZeroFunc as
-> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
-> Rec Identity as
-> (b, b -> Rec Identity as)
backpropWithN zfs f !xs = (y, g')
where
!(seqEither->(!tp0),!y) = unsafePerformIO $ fillWengert f xs
g' :: b -> Rec Identity as
g' = case tp0 of
Left i -> setInput i
Right tp -> g tp
{-# INLINE g' #-}
-- Backward pass: start with one empty slot per input, run 'gradRunner', then
-- read the input slots back.  The record of 'ZeroFunc' images of the inputs
-- is passed to 'fillRec' along with the slot contents.
-- NOTE(review): 'fillRec' is defined in "Data.Type.Util"; presumably it
-- returns 'Nothing' on a length mismatch, which is reported as an internal
-- error here -- confirm.
g :: (Int, [SomeTapeNode]) -> b -> Rec Identity as
g tp o = runST $ do
r <- initRunner tp . bimap getSum (`appEndo` [])
. VR.rfoldMap go
$ xs
gradRunner o r tp
delts <- toList <$> V.freeze (_rInputs r)
return . fromMaybe (internalError "backpropN") $
fillRec (\z -> maybe z (Identity . unsafeCoerce))
(VR.rzipWith (fmap . runZF) zfs xs)
delts
where
-- counts the inputs and builds the list of initial (empty) slots
go :: forall a. Identity a -> (Sum Int, Endo [Maybe Any])
go _ = (1, Endo (unsafeCoerce (Nothing @a) :))
{-# INLINE go #-}
-- The result is input number @i@: that position receives the seed itself,
-- every other position the 'ZeroFunc' image of its input.
setInput :: Int -> b -> Rec Identity as
setInput !i !x = go zfs xs 0
where
go :: Rec ZeroFunc bs -> Rec Identity bs -> Int -> Rec Identity bs
go = \case
RNil -> \_ _ -> RNil
z :& zs -> \case
q :& qs -> \(!j) ->
if j == i
then Identity (unsafeCoerce x) :& VR.rzipWith coerce zs qs
else coerce z q :& go zs qs (j + 1)
{-# INLINE setInput #-}
{-# INLINE backpropWithN #-}
-- | Evaluate a function on 'BVar's at the given inputs and return only its
-- result value; whatever 'fillWengert' reports about the tape is discarded.
evalBPN
    :: forall as b. ()
    => (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> b
evalBPN f = \xs -> snd (unsafePerformIO (fillWengert f xs))
{-# INLINE evalBPN #-}
-- | Run the forward pass: evaluate @f@ on input variables against a fresh
-- tape and report how the result relates to that tape.
--
-- The returned pair holds the result's value and one of:
--
-- * @'Left' i@ when the result variable /is/ input number @i@;
--
-- * @'Right' (n, nodes)@ otherwise, where @nodes@ is the tape (newest first)
--   and @n@ its length.
--
-- The tape is trimmed so that the node producing the result is always the
-- newest one, because the backward pass seeds node @n - 1@.  Without trimming,
-- the seed would be attributed to the wrong node whenever the result is not
-- the most recently recorded node -- for instance when an already-evaluated
-- variable is returned after some later node was forced (e.g. by a
-- comparison), or when the result is a constant while nodes exist.  Dropping
-- newer nodes is sound under the ordering the backward pass already relies
-- on: a node only refers to nodes with smaller indices.  For a constant
-- result the tape is empty, so the backward pass leaves every input untouched.
fillWengert
    :: forall as b. ()
    => (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> IO (Either Int (Int, [SomeTapeNode]), b)
fillWengert f xs = do
    w <- initWengert
    (outRef, o) <- reify w $ \(Proxy :: Proxy s) -> do
      let oVar = f (inpRec @s)
      evaluate (forceBVar oVar)
      -- Re-encode the reference without the phantom @s@ so it may leave the
      -- scope of 'reify': Left = input, Right (Just k) = tape node k,
      -- Right Nothing = constant.
      let ref :: Either Int (Maybe Int)
          ref = case _bvRef oVar of
            BRInp i -> Left i
            BRIx k  -> Right (Just k)
            BRC     -> Right Nothing
      pure (ref, _bvVal oVar)
    t <- case outRef of
      Left i   -> pure (Left i)
      Right mk -> do
        (n, tape) <- readIORef (wRef w)
        pure . Right $ case mk of
          -- node k sits (n - 1 - k) places from the head of the newest-first
          -- list; after dropping those it is the head and has index k
          Just k  -> (k + 1, drop (n - 1 - k) tape)
          Nothing -> (0, [])
    pure (t, o)
  where
    -- the inputs as variables, numbered from 0 in record order
    inpRec :: forall s. Rec (BVar s) as
    inpRec = evalState (rtraverse (state . go . runIdentity) xs) 0
      where
        go :: a -> Int -> (BVar s a, Int)
        go x i = (BV (BRInp i) x, i + 1)
        {-# INLINE go #-}
    {-# INLINE inpRec #-}
{-# INLINE fillWengert #-}
-- | Arithmetic on variables.  The binary and unary operations are lifted
-- with 'liftOp2' / 'liftOp1', always using 'afNum' for the inputs;
-- 'fromInteger' produces a constant via 'constVar'.
instance (Num a, Reifies s W) => Num (BVar s a) where
(+) = liftOp2 afNum afNum (+.)
{-# INLINE (+) #-}
(-) = liftOp2 afNum afNum (-.)
{-# INLINE (-) #-}
(*) = liftOp2 afNum afNum (*.)
{-# INLINE (*) #-}
negate = liftOp1 afNum negateOp
{-# INLINE negate #-}
signum = liftOp1 afNum signumOp
{-# INLINE signum #-}
abs = liftOp1 afNum absOp
{-# INLINE abs #-}
-- literals are constants and are not recorded on the tape
fromInteger = constVar . fromInteger
{-# INLINE fromInteger #-}
-- | Division and reciprocal are lifted with 'liftOp2' / 'liftOp1' using
-- 'afNum'; 'fromRational' produces a constant via 'constVar'.
instance (Fractional a, Reifies s W) => Fractional (BVar s a) where
(/) = liftOp2 afNum afNum (/.)
{-# INLINE (/) #-}
recip = liftOp1 afNum recipOp
{-# INLINE recip #-}
-- literals are constants and are not recorded on the tape
fromRational = constVar . fromRational
{-# INLINE fromRational #-}
-- | Transcendental functions on variables.  Each method is lifted with
-- 'liftOp1' or 'liftOp2' using 'afNum' and the correspondingly named op;
-- 'pi' is a constant.
instance (Floating a, Reifies s W) => Floating (BVar s a) where
-- a constant, not recorded on the tape
pi = constVar pi
{-# INLINE pi #-}
exp = liftOp1 afNum expOp
{-# INLINE exp #-}
log = liftOp1 afNum logOp
{-# INLINE log #-}
sqrt = liftOp1 afNum sqrtOp
{-# INLINE sqrt #-}
(**) = liftOp2 afNum afNum (**.)
{-# INLINE (**) #-}
logBase = liftOp2 afNum afNum logBaseOp
{-# INLINE logBase #-}
sin = liftOp1 afNum sinOp
{-# INLINE sin #-}
cos = liftOp1 afNum cosOp
{-# INLINE cos #-}
tan = liftOp1 afNum tanOp
{-# INLINE tan #-}
asin = liftOp1 afNum asinOp
{-# INLINE asin #-}
acos = liftOp1 afNum acosOp
{-# INLINE acos #-}
atan = liftOp1 afNum atanOp
{-# INLINE atan #-}
sinh = liftOp1 afNum sinhOp
{-# INLINE sinh #-}
cosh = liftOp1 afNum coshOp
{-# INLINE cosh #-}
tanh = liftOp1 afNum tanhOp
{-# INLINE tanh #-}
asinh = liftOp1 afNum asinhOp
{-# INLINE asinh #-}
acosh = liftOp1 afNum acoshOp
{-# INLINE acosh #-}
atanh = liftOp1 afNum atanhOp
{-# INLINE atanh #-}
-- | Variables are compared by the values they carry; the references are
-- ignored.
instance Eq a => Eq (BVar s a) where
    x == y = _bvVal x == _bvVal y
    x /= y = _bvVal x /= _bvVal y
-- | Variables are ordered by the values they carry; the references are
-- ignored.
instance Ord a => Ord (BVar s a) where
    compare x y = compare (_bvVal x) (_bvVal y)
    x <  y = _bvVal x <  _bvVal y
    x <= y = _bvVal x <= _bvVal y
    x >  y = _bvVal x >  _bvVal y
    x >= y = _bvVal x >= _bvVal y
-- | 'traverse' with access to each element's position.  Positions start at 0
-- and follow traversal order; the counter is threaded with 'StateT'.
itraverse
    :: forall t a b f. (Traversable t, Monad f)
    => (Int -> a -> f b) -> t a -> f (t b)
itraverse f xs = evalStateT (traverse step xs) 0
  where
    -- run the action at the current position, then advance the counter
    step x = StateT $ \i -> fmap (\y -> (y, i + 1)) (f i x)
{-# INLINE itraverse #-}
-- | Lens onto the element at a given position of a list (positions count
-- from 0).  Running off the end of the list is reported through
-- 'internalError'.
ixi :: Int -> Lens' [a] a
ixi n f ys = case ys of
    [] -> internalError "ixi"
    x : xs
      | n == 0    -> fmap (: xs) (f x)
      | otherwise -> fmap (x :) (ixi (n - 1) f xs)
{-# INLINE ixi #-}
-- | Lens onto the @i@-th target of a traversal.
--
-- The targets are listed, the @i@-th one is focused with 'ixi', and the
-- structure is rebuilt by running the traversal once more, feeding it the
-- (possibly updated) list one element per target.
ixt :: forall b a. Traversal' b a -> Int -> Lens' b a
ixt t i f xs = fmap rebuild (ixi i f targets)
  where
    targets = toListOf t xs
    rebuild = evalState (traverseOf t (\_ -> state pop) xs)
    -- hand out the next replacement element
    pop :: [a] -> (a, [a])
    pop []       = internalError "ixt"
    pop (y : ys) = (y, ys)
{-# INLINE ixt #-}
-- | 'Backprop' operations on variables, obtained by lifting the operations
-- of the underlying type with 'liftOp1' / 'liftOp2' and 'addFunc'.
--
-- NOTE(review): 'op1' / 'op2' come from "Numeric.Backprop.Op"; the lambdas
-- below presumably return the result paired with the gradient function --
-- confirm against that module.  Read that way, 'zero' and 'one' send the
-- 'zero' of the incoming value back to their input, and 'add' passes the
-- incoming value unchanged to both inputs.
instance (Backprop a, Reifies s W) => Backprop (BVar s a) where
zero = liftOp1 addFunc . op1 $ \x -> (zero x, zero)
{-# INLINE zero #-}
add = liftOp2 addFunc addFunc . op2 $ \x y ->
( add x y
, \d -> (d, d)
)
{-# INLINE add #-}
one = liftOp1 addFunc . op1 $ \x -> (one x, zero)
{-# INLINE one #-}
-- | 'ZeroFunc' backed by the 'zero' method of a 'Backprop' instance.
zeroFunc :: Backprop a => ZeroFunc a
zeroFunc = ZF (\x -> zero x)
{-# INLINE zeroFunc #-}
-- | 'AddFunc' backed by the 'add' method of a 'Backprop' instance.
addFunc :: Backprop a => AddFunc a
addFunc = AF (\x y -> add x y)
{-# INLINE addFunc #-}
-- | 'OneFunc' backed by the 'one' method of a 'Backprop' instance.
oneFunc :: Backprop a => OneFunc a
oneFunc = OF (\x -> one x)
{-# INLINE oneFunc #-}
-- | Raise an error (without a stack trace) for a situation this module
-- treats as impossible.  The argument names the function that detected the
-- problem and is spliced into the message after the module name.
internalError :: String -> a
internalError m = errorWithoutStackTrace message
  where
    message = concat
      [ "Numeric.Backprop.Internal."
      , m
      , ": unexpected shape involved in gradient computation"
      ]