{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Numeric.Backprop.Internal (
BVar
, W
, backpropN, evalBPN
, constVar
, liftOp, liftOp1, liftOp2, liftOp3
, viewVar, setVar, sequenceVar, collectVar, previewVar, toListOfVar
, coerceVar
, ZeroFunc(..), zfNum
, AddFunc(..), afNum
, OneFunc(..), ofNum
, debugSTN
, debugIR
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Coerce
import Data.Foldable
import Data.Function
import Data.IORef
import Data.Kind
import Data.Maybe
import Data.Monoid hiding (Any(..))
import Data.Proxy
import Data.Reflection
import Data.Type.Conjunction
import Data.Type.Product hiding (toList)
import Data.Type.Util
import Data.Type.Vector hiding (itraverse)
import Data.Typeable
import GHC.Exts (Any)
import GHC.Generics
import Lens.Micro
import Numeric.Backprop.Op
import System.IO.Unsafe
import Type.Class.Higher
import Unsafe.Coerce
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | A wrapped function @a -> a@.  Within this module it is applied to a
-- value to obtain the starting ("zero") gradient stored for that value;
-- see 'zfNum' for the 'Num'-based instance.
newtype ZeroFunc a = ZF { runZF :: a -> a }
-- | A wrapped binary function @a -> a -> a@.  Within this module it is
-- used to combine an incoming gradient contribution with the value already
-- accumulated; see 'afNum' for the 'Num'-based instance.
newtype AddFunc a = AF { runAF :: a -> a -> a }
-- | A wrapped function @a -> a@.  'backpropN' applies it to the computed
-- result to obtain the gradient that seeds the backward pass; see 'ofNum'
-- for the 'Num'-based instance.
newtype OneFunc a = OF { runOF :: a -> a }
-- | 'ZeroFunc' for any 'Num' type: ignores its argument and yields @0@.
zfNum :: Num a => ZeroFunc a
zfNum = ZF (\_ -> 0)
{-# INLINE zfNum #-}
-- | 'AddFunc' for any 'Num' type: combines two values with '+'.
afNum :: Num a => AddFunc a
afNum = AF { runAF = (+) }
{-# INLINE afNum #-}
-- | 'OneFunc' for any 'Num' type: ignores its argument and yields @1@.
ofNum :: Num a => OneFunc a
ofNum = OF (\_ -> 1)
{-# INLINE ofNum #-}
-- | A value of type @a@ together with a 'BRef' recording where it came
-- from.  Both fields are strict, so evaluating a 'BVar' to weak head normal
-- form also evaluates the reference and the value to weak head normal form.
--
-- The @s@ parameter is phantom here; elsewhere in this module it is tied to
-- a reified 'W' through a @Reifies s W@ constraint.
data BVar s a = BV { _bvRef :: !(BRef s)  -- ^ origin: input slot, tape node, or constant
                   , _bvVal :: !a         -- ^ the wrapped value
                   }
deriving instance Typeable (BVar s a)
-- | Where a 'BVar' came from.  During the backward pass ('gradRunner') the
-- constructor decides which mutable vector receives gradient contributions.
data BRef (s :: Type) = BRInp !Int  -- ^ index into the inputs vector
                      | BRIx !Int   -- ^ index into the per-node delta vector
                      | BRC         -- ^ constant; contributions are discarded
                  deriving (Generic, Show)
-- No methods are given, so 'rnf' comes from the class default.
instance NFData (BRef s)
-- | Fully evaluates the reference first, then the wrapped value.
instance NFData a => NFData (BVar s a) where
    rnf (BV ref val) = rnf ref `seq` rnf val
-- | Extract the value of a 'BVar' if (and only if) it is a constant, i.e.
-- its reference is 'BRC'.
bvConst :: BVar s a -> Maybe a
bvConst (BV ref val) = case ref of
    BRC -> Just val
    _   -> Nothing
{-# INLINE bvConst #-}
-- | Evaluate a 'BVar' far enough to be safely recorded: the constructor
-- (and with it the strict value field) to weak head normal form, and the
-- reference to normal form.
forceBVar :: BVar s a -> ()
forceBVar (BV ref _) = rnf ref
{-# INLINE forceBVar #-}
-- | A reference from a tape node to one of its inputs.  The type index @a@
-- is the type of the gradient piece the node produces for this input; the
-- type @b@ of the referenced variable (and its state tag @s@) are hidden
-- existentially.
data InpRef :: Type -> Type where
    IR :: { _irIx  :: !(BVar s b)        -- ^ the input variable being referenced
          , _irAdd :: !(a -> b -> b)     -- ^ folds a gradient piece into the value accumulated for that variable
          }
       -> InpRef a
-- | Force an 'InpRef': matching the constructor evaluates its strict
-- fields, and the referenced variable is additionally run through
-- 'forceBVar'.
forceInpRef :: InpRef a -> ()
forceInpRef (IR var _) = forceBVar var
{-# INLINE forceInpRef #-}
-- | Debugging string for an 'InpRef': the 'Show' rendering of the
-- reference of the variable it points at.
debugIR :: InpRef a -> String
debugIR (IR var _) = show (_bvRef var)
-- | One recorded operation.  The type index @a@ is the type of the
-- operation's result (and hence of the gradient flowing into it); the list
-- @as@ of input gradient types is hidden existentially.
data TapeNode :: Type -> Type where
    TN :: { _tnInputs :: !(Prod InpRef as)   -- ^ one reference per input of the operation
          , _tnGrad   :: !(a -> Tuple as)    -- ^ maps the gradient of the result to one gradient piece per input
          }
       -> TapeNode a
-- | Force a 'TapeNode': the constructor and its strict fields are evaluated
-- by the pattern match, and every input reference is then run through
-- 'forceInpRef'.
--
-- The per-input results are chained with 'seq' via 'Endo' rather than
-- combined in the @()@ monoid: @(<>)@ on @()@ ignores both of its operands,
-- so folding directly into @()@ never evaluated the individual
-- 'forceInpRef' calls.
forceTapeNode :: TapeNode a -> ()
forceTapeNode (TN inps !_) = appEndo (foldMap1 forceStep inps) ()
  where
    -- One link of the chain: force this reference, then continue.
    forceStep :: InpRef b -> Endo ()
    forceStep ref = Endo (forceInpRef ref `seq`)
{-# INLINE forceTapeNode #-}
-- | A 'TapeNode' with its result type hidden, so that nodes of different
-- types can be kept in one list.
data SomeTapeNode :: Type where
    STN :: { _stnZero :: a                  -- ^ initial value for this node's delta slot (lazy field)
           , _stnNode :: !(TapeNode a)      -- ^ the recorded operation
           }
        -> SomeTapeNode
-- | Debugging string for a 'SomeTapeNode': the 'Show' rendering of the
-- list of 'debugIR' strings of its inputs.
debugSTN :: SomeTapeNode -> String
debugSTN (STN _ (TN inputs _)) = show (foldMap1 (\ref -> [debugIR ref]) inputs)
-- | The tape being recorded: a mutable cell holding the number of nodes
-- inserted so far and the nodes themselves.  'insertNode' conses new nodes
-- onto the front, so the list is ordered newest first.
newtype W = W { wRef :: IORef (Int, [SomeTapeNode]) }
-- | Allocate a fresh, empty tape: node count @0@ and no nodes.
initWengert :: IO W
initWengert = do
    ref <- newIORef (0, [])
    return (W ref)
{-# INLINE initWengert #-}
-- | Record a node on the tape and return the 'BVar' that refers to it.
--
-- The node is consed onto the front of the tape together with its initial
-- delta (@runZF zf x@), and the node count is incremented, all inside a
-- single 'atomicModifyIORef''.  The returned variable carries the value @x@
-- and a 'BRIx' reference holding the node count from /before/ the
-- insertion, i.e. the new node's index.
insertNode
    :: TapeNode a
    -> a
    -> ZeroFunc a
    -> W
    -> IO (BVar s a)
insertNode tn !x zf !w = do
    ix <- atomicModifyIORef' (wRef w) push
    return (BV (BRIx ix) x)
  where
    -- New tape state paired with the index handed back to the caller.  The
    -- node, the new count and the new list are all forced before the pair
    -- is produced.
    push (!n, !t) =
        forceTapeNode tn `seq` n' `seq` t' `seq` ((n', t'), n)
      where
        n' = n + 1
        t' = STN (runZF zf x) tn : t
{-# INLINE insertNode #-}
-- | Wrap a plain value as a constant 'BVar' (reference 'BRC').  Nothing is
-- recorded on the tape.
constVar :: a -> BVar s a
constVar = \val -> BV { _bvRef = BRC, _bvVal = val }
{-# INLINE constVar #-}
-- | Apply an 'Op' to a product of 'BVar's, in 'IO'.
--
-- If every input is a constant, the op is simply evaluated and the result
-- is itself a constant.  Otherwise the op is run on the inputs' values and
-- a node holding the input references (each paired with its 'AddFunc') and
-- the op's gradient function is inserted into the reified tape.
liftOp_
    :: forall s as b. Reifies s W
    => Prod AddFunc as
    -> ZeroFunc b
    -> Op as b
    -> Prod (BVar s) as
    -> IO (BVar s b)
liftOp_ afs z o !vs =
    maybe recordNode (return . constVar . evalOp o) allConsts
  where
    -- 'Just' the plain input values when all inputs are constants.
    allConsts = traverse1 (fmap I . bvConst) vs
    recordNode = insertNode node out z (reflect (Proxy @s))
    (out, grad) = runOpWith o (map1 (I . _bvVal) vs)
    node = TN { _tnInputs = map1 toInpRef (zipP afs vs)
              , _tnGrad   = grad
              }
    toInpRef :: forall a. (AddFunc :&: BVar s) a -> InpRef a
    toInpRef (af :&: (!v)) = forceBVar v `seq` IR v (runAF af)
    {-# INLINE toInpRef #-}
{-# INLINE liftOp_ #-}
-- | Pure interface to 'liftOp_': the inputs are evaluated to weak head
-- normal form and the recording action is then run with
-- 'unsafePerformIO'.
liftOp
    :: forall as b s. Reifies s W
    => Prod AddFunc as
    -> ZeroFunc b
    -> Op as b
    -> Prod (BVar s) as
    -> BVar s b
liftOp afs z o vs = vs `seq` unsafePerformIO (liftOp_ afs z o vs)
{-# INLINE liftOp #-}
-- | Single-input version of 'liftOp_'.  A constant input yields a constant
-- result with nothing recorded; otherwise the op is run and a one-input
-- node is inserted into the reified tape.
liftOp1_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Op '[a] b
    -> BVar s a
    -> IO (BVar s b)
liftOp1_ af z o v = case bvConst v of
    Just x  -> return (constVar (evalOp o (x ::< Ø)))
    Nothing -> forceBVar v `seq` insertNode node out z (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< Ø)
    node = TN { _tnInputs = IR v (runAF af) :< Ø
              , _tnGrad   = grad
              }
{-# INLINE liftOp1_ #-}
-- | Pure interface to 'liftOp1_': the input is evaluated to weak head
-- normal form and the recording action is then run with
-- 'unsafePerformIO'.
liftOp1
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Op '[a] b
    -> BVar s a
    -> BVar s b
liftOp1 af z o v = v `seq` unsafePerformIO (liftOp1_ af z o v)
{-# INLINE liftOp1 #-}
-- | Two-input version of 'liftOp_'.  Only when /both/ inputs are constants
-- is the result a constant with nothing recorded; in every other case the
-- op is run and a two-input node is inserted into the reified tape.
liftOp2_
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc c
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> IO (BVar s c)
liftOp2_ afa afb z o v u = case (bvConst v, bvConst u) of
    (Just x, Just y) -> return . constVar . evalOp o $ x ::< y ::< Ø
    _                -> forceBVar v
                  `seq` forceBVar u
                  `seq` insertNode node out z (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< _bvVal u ::< Ø)
    node = TN { _tnInputs = IR v (runAF afa) :< IR u (runAF afb) :< Ø
              , _tnGrad   = grad
              }
{-# INLINE liftOp2_ #-}
-- | Pure interface to 'liftOp2_': both inputs are evaluated to weak head
-- normal form (first, then second) and the recording action is then run
-- with 'unsafePerformIO'.
liftOp2
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc c
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> BVar s c
liftOp2 afa afb z o v u =
    v `seq` u `seq` unsafePerformIO (liftOp2_ afa afb z o v u)
{-# INLINE liftOp2 #-}
-- | Three-input version of 'liftOp_'.  Only when /all three/ inputs are
-- constants is the result a constant with nothing recorded; in every other
-- case the op is run and a three-input node is inserted into the reified
-- tape.
liftOp3_
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> ZeroFunc d
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> IO (BVar s d)
liftOp3_ afa afb afc z o v u w = case (bvConst v, bvConst u, bvConst w) of
    (Just p, Just q, Just r) ->
        return . constVar . evalOp o $ p ::< q ::< r ::< Ø
    _ -> forceBVar v
   `seq` forceBVar u
   `seq` forceBVar w
   `seq` insertNode node out z (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< _bvVal u ::< _bvVal w ::< Ø)
    node = TN { _tnInputs = IR v (runAF afa)
                         :< IR u (runAF afb)
                         :< IR w (runAF afc)
                         :< Ø
              , _tnGrad   = grad
              }
{-# INLINE liftOp3_ #-}
-- | Pure interface to 'liftOp3_': the three inputs are evaluated to weak
-- head normal form in argument order and the recording action is then run
-- with 'unsafePerformIO'.
liftOp3
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> ZeroFunc d
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
liftOp3 afa afb afc z o v u w =
    v `seq` u `seq` w `seq` unsafePerformIO (liftOp3_ afa afb afc z o v u w)
{-# INLINE liftOp3 #-}
-- | View a part of a 'BVar' through a lens, in 'IO'.
--
-- The result carries the focused value.  A node is always inserted (there
-- is no shortcut for constants); gradient pieces arriving at that node are
-- folded into the parent by applying the 'AddFunc' under the same lens.
viewVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s b
    -> IO (BVar s a)
viewVar_ af z l v =
    forceBVar v `seq` insertNode node focus z (reflect (Proxy @s))
  where
    focus = _bvVal v ^. l
    node = TN { _tnInputs = IR v (\d -> over l (runAF af d)) :< Ø
              , _tnGrad   = only_
              }
{-# INLINE viewVar_ #-}
-- | Pure interface to 'viewVar_': the variable is evaluated to weak head
-- normal form and the recording action is then run with
-- 'unsafePerformIO'.
viewVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s b
    -> BVar s a
viewVar af z l v = v `seq` unsafePerformIO (viewVar_ af z l v)
{-# INLINE viewVar #-}
-- | Replace a part of a 'BVar' (selected by a lens) with another 'BVar',
-- in 'IO'.
--
-- The result carries the whole value with the focused part overwritten.  On
-- the backward pass the incoming gradient is split in two: the focused part
-- goes to the variable that was written in, and the gradient with that part
-- replaced by its zero goes to the original whole.
setVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> IO (BVar s b)
setVar_ afa afb za zb l w v =
    forceBVar v `seq` forceBVar w `seq`
      insertNode node updated zb (reflect (Proxy @s))
  where
    updated = (l .~ _bvVal w) (_bvVal v)
    node = TN { _tnInputs = IR w (runAF afa) :< IR v (runAF afb) :< Ø
              , _tnGrad   = splitGrad
              }
    -- Running the lens in the pair functor yields both the focused part
    -- and the whole with that part zeroed in one pass.  The pair is matched
    -- lazily.
    splitGrad d = dw ::< dv ::< Ø
      where
        (dw, dv) = l (\x -> (x, runZF za x)) d
{-# INLINE setVar_ #-}
-- | Pure interface to 'setVar_': the new part and then the whole are
-- evaluated to weak head normal form, and the recording action is then run
-- with 'unsafePerformIO'.
setVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> BVar s b
setVar afa afb za zb l w v =
    w `seq` v `seq` unsafePerformIO (setVar_ afa afb za zb l w v)
{-# INLINE setVar #-}
-- | Turn a 'BVar' holding a traversable container into a container of
-- 'BVar's, one per element, by running 'traverseVar'' with 'traverse' as
-- the traversal and 'id' as the element extractor.
sequenceVar
    :: forall t a s. (Reifies s W, Traversable t)
    => AddFunc a
    -> ZeroFunc a
    -> BVar s (t a)
    -> t (BVar s a)
sequenceVar af z v = v `seq` unsafePerformIO (traverseVar' af z id traverse v)
{-# INLINE sequenceVar #-}
-- | Gather a container of 'BVar's into a single 'BVar' of the container,
-- in 'IO'.
--
-- The elements are listed into a length-indexed vector so the node can
-- carry one input reference per element.  On the backward pass the
-- incoming container is listed and zipped against the inputs; where an
-- entry is missing, that input's zero is used instead.
collectVar_
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> ZeroFunc (t a)
    -> t (BVar s a)
    -> IO (BVar s (t a))
collectVar_ af z z' !vs = withV (toList vs) $ \(inputs :: Vec n (BVar s a)) -> do
    let -- Gradient piece for one input, defaulting to that input's zero.
        fillMissing (I v) md = I (fromMaybe (runZF z (_bvVal v)) md)
        node :: TapeNode (t a)
        node = TN
          { _tnInputs = vecToProd (vmap (\(I v) -> IR v (runAF af)) inputs)
          , _tnGrad   = \d -> vecToProd (zipVecList fillMissing inputs (toList d))
          }
    -- Every element is forced before the node is put on the tape.
    for_ vs (evaluate . forceBVar)
    insertNode node (fmap _bvVal vs) z' (reflect (Proxy @s))
{-# INLINE collectVar_ #-}
-- | Pure interface to 'collectVar_': the container is evaluated to weak
-- head normal form and the recording action is then run with
-- 'unsafePerformIO'.
collectVar
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> ZeroFunc (t a)
    -> t (BVar s a)
    -> BVar s (t a)
collectVar af z z' vs = vs `seq` unsafePerformIO (collectVar_ af z z' vs)
{-# INLINE collectVar #-}
-- | Shared worker for 'sequenceVar', 'previewVar' and 'toListOfVar'.
--
-- The extractor @f@ pulls the targets out of the parent value into some
-- traversable shape; each target, numbered from zero in traversal order,
-- gets its own tape node.  Gradient arriving at the @i@-th node is folded
-- into the parent at the @i@-th target of the traversal (via 'ixt').
traverseVar'
    :: forall b a f s. (Reifies s W, Traversable f)
    => AddFunc a
    -> ZeroFunc a
    -> (b -> f a)
    -> Traversal' b a
    -> BVar s b
    -> IO (f (BVar s a))
traverseVar' af z f t v =
    forceBVar v `seq` itraverse record (f (_bvVal v))
  where
    record :: Int -> a -> IO (BVar s a)
    record i target = insertNode node target z (reflect (Proxy @s))
      where
        node = TN { _tnInputs = IR v (\d -> over (ixt t i) (runAF af d)) :< Ø
                  , _tnGrad   = only_
                  }
    {-# INLINE record #-}
{-# INLINE traverseVar' #-}
-- | 'BVar' for the first target of a traversal, if there is one.  Built on
-- 'traverseVar'' with the extractor \"first element of the target list\".
previewVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc a
    -> Traversal' b a
    -> BVar s b
    -> Maybe (BVar s a)
previewVar af z t v =
    v `seq` unsafePerformIO (traverseVar' af z firstTarget t v)
  where
    firstTarget whole = listToMaybe (whole ^.. t)
{-# INLINE previewVar #-}
-- | 'BVar's for every target of a traversal, in traversal order.  Built on
-- 'traverseVar'' with the extractor \"list of all targets\".
toListOfVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc a
    -> Traversal' b a
    -> BVar s b
    -> [BVar s a]
toListOfVar af z t v =
    v `seq` unsafePerformIO (traverseVar' af z (\whole -> whole ^.. t) t v)
{-# INLINE toListOfVar #-}
-- | Re-type a 'BVar' at a 'Coercible' type.  The reference is kept as is
-- and the value is converted with 'coerce'; nothing is recorded on the
-- tape.
--
-- Marked INLINE to match the other exported 'BVar' combinators in this
-- module, all of which carry the pragma.
coerceVar
    :: Coercible a b
    => BVar s a
    -> BVar s b
coerceVar v@(BV r x) = forceBVar v `seq` BV r (coerce x)
{-# INLINE coerceVar #-}
-- | Mutable state of one backward pass.  Both vectors hold values erased
-- to 'Any'; they are written and read through 'unsafeCoerce'.
data Runner s = R { _rDelta  :: !(MV.MVector s Any)   -- ^ one slot per tape node, indexed as in 'BRIx'
                  , _rInputs :: !(MV.MVector s Any)   -- ^ one slot per input, indexed as in 'BRInp'
                  }
-- | Allocate the vectors for a backward pass.
--
-- The first argument is the recorded tape (node count and nodes, newest
-- first); the delta vector gets one slot per node, filled with that node's
-- stored zero.  The second argument is the number of inputs and their
-- initial values, which fill the inputs vector in order.
initRunner
    :: (PrimMonad m, PrimState m ~ s)
    => (Int, [SomeTapeNode])
    -> (Int, [Any])
    -> m (Runner s)
initRunner (n, stns) (nx, xs) = do
    delts <- MV.new n
    -- The newest node is first in the list and owns the highest index.
    zipWithM_ (\i (STN z _) -> MV.write delts i (unsafeCoerce z))
              [n-1,n-2..] stns
    inps <- MV.new nx
    zipWithM_ (MV.write inps) [0..] xs
    return (R delts inps)
{-# INLINE initRunner #-}
-- | Run the backward pass over a recorded tape.
--
-- When the tape is non-empty, the seed gradient @o@ is written into the
-- delta slot of the newest node (index @n - 1@).  The nodes are then
-- visited newest first: each node's accumulated delta is read, pushed
-- through the node's gradient function, and every resulting piece is folded
-- into the slot of the corresponding input: the inputs vector for 'BRInp',
-- the delta vector for 'BRIx', nowhere for 'BRC'.
gradRunner
    :: forall m b s. (PrimMonad m, PrimState m ~ s)
    => b
    -> Runner s
    -> (Int, [SomeTapeNode])
    -> m ()
gradRunner o (R deltas inputs) (n, stns) = do
    when (n > 0) $
      MV.write deltas (n - 1) (unsafeCoerce o)
    zipWithM_ visit [n-1,n-2..] stns
  where
    visit :: Int -> SomeTapeNode -> m ()
    visit i (STN _ (TN refs grad)) = do
        delt <- MV.read deltas i
        zipWithPM_ propagate refs (grad (unsafeCoerce delt))
    {-# INLINE visit #-}
    propagate :: forall x. InpRef x -> I x -> m ()
    propagate (IR v add) (I d) = case _bvRef v of
        BRInp i -> MV.modify inputs bump i
        BRIx  i -> MV.modify deltas bump i
        BRC     -> return ()
      where
        -- Fold the piece into a type-erased slot value.
        bump :: Any -> Any
        bump = unsafeCoerce . add d . unsafeCoerce
    {-# INLINE propagate #-}
{-# INLINE gradRunner #-}
-- | Run a function on 'BVar's over the given inputs and return its result
-- together with a tuple of gradients, one per input.
--
-- Arguments: a 'ZeroFunc' per input (giving each input's starting
-- gradient), a 'OneFunc' applied to the result to produce the seed gradient
-- for the backward pass, the function itself, and the inputs.
--
-- Calls 'error' with @\"backpropN\"@ if 'fillProd' returns 'Nothing' when
-- rebuilding the gradient tuple.
--
-- NOTE(review): 'gradRunner' writes the seed into the slot of the most
-- recently recorded node and does nothing when the tape is empty.  If @f@
-- returns one of its inputs or a constant directly, it looks as though the
-- inputs would keep their zero gradients; confirm whether that is intended.
backpropN
    :: forall as b. ()
    => Prod ZeroFunc as
    -> OneFunc b
    -> (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> (b, Tuple as)
backpropN zfs ofb f !xs = (y, g)
  where
    -- Strict binding: the tape pair (its count and list) and the result are
    -- evaluated to weak head normal form as soon as this clause is entered.
    !(!tp@(!_,!_),!y) = unsafePerformIO $ fillWengert f xs
    g :: Tuple as
    g = runST $ do
        -- The pair applicative is used as a writer: walking the inputs
        -- accumulates their count and the list of their type-erased zeros.
        r <- initRunner tp $ bimap getSum (`appEndo` [])
                           . fst
                           $ zipWithPM_ go zfs xs
        gradRunner (runOF ofb y) r tp
        delts <- toList <$> V.freeze (_rInputs r)
        -- Restore the erased input gradients to their types, using @xs@
        -- only for its shape.
        return . fromMaybe (error "backpropN") $
          fillProd (\_ d -> I (unsafeCoerce d)) xs delts
      where
        go :: forall a. ZeroFunc a -> I a -> ((Sum Int, Endo [Any]),())
        go zf (I x) = ((1, Endo (unsafeCoerce (runZF zf x) :)), ())
        {-# INLINE go #-}
{-# INLINE backpropN #-}
-- | Run a function on 'BVar's over the given inputs and return only its
-- result.  The tape is still recorded by 'fillWengert' and then dropped.
evalBPN
    :: forall as b. ()
    => (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> b
evalBPN f = \xs -> snd (unsafePerformIO (fillWengert f xs))
{-# INLINE evalBPN #-}
-- | Record the tape of a function on 'BVar's.
--
-- A fresh tape is created and reified as the @s@ parameter; the function is
-- applied to the inputs wrapped as 'BRInp' variables (numbered from zero in
-- order) and its output variable is forced, which drives the recording.
-- Returns the final tape contents (node count and nodes, newest first)
-- together with the output value.
fillWengert
    :: forall as b. ()
    => (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> IO ((Int, [SomeTapeNode]), b)
fillWengert f xs = do
    w <- initWengert
    o <- reify w $ \(Proxy :: Proxy s) -> do
      let oVar = f (inpProd @s)
      -- Forcing the output happens before the tape is read below.
      evaluate (forceBVar oVar)
      return (_bvVal oVar)
    t <- readIORef (wRef w)
    return (t, o)
  where
    -- The inputs as 'BVar's; a state counter hands out consecutive indices.
    inpProd :: forall s. Prod (BVar s) as
    inpProd = evalState (traverse1 (state . go . getI) xs) 0
      where
        go :: a -> Int -> (BVar s a, Int)
        go x i = (BV (BRInp i) x, i + 1)
        {-# INLINE go #-}
    {-# INLINE inpProd #-}
{-# INLINE fillWengert #-}
-- | Arithmetic on 'BVar's.  Each operation is lifted with 'liftOp1' or
-- 'liftOp2' using the 'Num'-based 'afNum' and 'zfNum'; the @Op@ values
-- themselves are imported (presumably from "Numeric.Backprop.Op").
-- 'fromInteger' produces a constant and records nothing.
instance (Num a, Reifies s W) => Num (BVar s a) where
    (+) = liftOp2 afNum afNum zfNum (+.)
    {-# INLINE (+) #-}
    (-) = liftOp2 afNum afNum zfNum (-.)
    {-# INLINE (-) #-}
    (*) = liftOp2 afNum afNum zfNum (*.)
    {-# INLINE (*) #-}
    negate = liftOp1 afNum zfNum negateOp
    {-# INLINE negate #-}
    signum = liftOp1 afNum zfNum signumOp
    {-# INLINE signum #-}
    abs = liftOp1 afNum zfNum absOp
    {-# INLINE abs #-}
    fromInteger = constVar . fromInteger
    {-# INLINE fromInteger #-}
-- | Division on 'BVar's, lifted with 'liftOp1' / 'liftOp2' using 'afNum'
-- and 'zfNum'.  'fromRational' produces a constant and records nothing.
instance (Fractional a, Reifies s W) => Fractional (BVar s a) where
    (/) = liftOp2 afNum afNum zfNum (/.)
    {-# INLINE (/) #-}
    recip = liftOp1 afNum zfNum recipOp
    {-# INLINE recip #-}
    fromRational = constVar . fromRational
    {-# INLINE fromRational #-}
-- | Transcendental functions on 'BVar's.  Unary functions are lifted with
-- 'liftOp1' and binary ones ('**', 'logBase') with 'liftOp2', always using
-- 'afNum' and 'zfNum'.  'pi' is a constant and records nothing.
instance (Floating a, Reifies s W) => Floating (BVar s a) where
    pi = constVar pi
    {-# INLINE pi #-}
    exp = liftOp1 afNum zfNum expOp
    {-# INLINE exp #-}
    log = liftOp1 afNum zfNum logOp
    {-# INLINE log #-}
    sqrt = liftOp1 afNum zfNum sqrtOp
    {-# INLINE sqrt #-}
    (**) = liftOp2 afNum afNum zfNum (**.)
    {-# INLINE (**) #-}
    logBase = liftOp2 afNum afNum zfNum logBaseOp
    {-# INLINE logBase #-}
    sin = liftOp1 afNum zfNum sinOp
    {-# INLINE sin #-}
    cos = liftOp1 afNum zfNum cosOp
    {-# INLINE cos #-}
    tan = liftOp1 afNum zfNum tanOp
    {-# INLINE tan #-}
    asin = liftOp1 afNum zfNum asinOp
    {-# INLINE asin #-}
    acos = liftOp1 afNum zfNum acosOp
    {-# INLINE acos #-}
    atan = liftOp1 afNum zfNum atanOp
    {-# INLINE atan #-}
    sinh = liftOp1 afNum zfNum sinhOp
    {-# INLINE sinh #-}
    cosh = liftOp1 afNum zfNum coshOp
    {-# INLINE cosh #-}
    tanh = liftOp1 afNum zfNum tanhOp
    {-# INLINE tanh #-}
    asinh = liftOp1 afNum zfNum asinhOp
    {-# INLINE asinh #-}
    acosh = liftOp1 afNum zfNum acoshOp
    {-# INLINE acosh #-}
    atanh = liftOp1 afNum zfNum atanhOp
    {-# INLINE atanh #-}
-- | Equality looks only at the wrapped values; references are ignored.
instance Eq a => Eq (BVar s a) where
    u == v = _bvVal u == _bvVal v
    u /= v = _bvVal u /= _bvVal v
-- | Ordering looks only at the wrapped values; references are ignored.
instance Ord a => Ord (BVar s a) where
    compare u v = compare (_bvVal u) (_bvVal v)
    u <  v = _bvVal u <  _bvVal v
    u <= v = _bvVal u <= _bvVal v
    u >  v = _bvVal u >  _bvVal v
    u >= v = _bvVal u >= _bvVal v
-- | 'traverse' with a running position: the action receives each element
-- together with its zero-based index in traversal order.  The counter is
-- threaded through a 'StateT' layer over the underlying monad.
itraverse
    :: forall t a b f. (Traversable t, Monad f)
    => (Int -> a -> f b) -> t a -> f (t b)
itraverse f xs = evalStateT (traverse step xs) 0
  where
    -- Run the action at the current index, then advance the counter.
    step :: a -> StateT Int f b
    step x = StateT $ \i -> fmap (\y -> (y, i + 1)) (f i x)
{-# INLINE itraverse #-}
-- | Lens onto the element at a zero-based position in a list.
--
-- Partial: calls 'error' when the index is negative, or when the list ends
-- before the index is reached.  A negative index is rejected up front;
-- counting down from it can never reach zero, so the walk would otherwise
-- run to the end of the list (and never finish on an infinite one).
--
-- The recursion lives in a local worker so that the INLINE pragma applies
-- to a non-recursive wrapper.
ixi :: Int -> Lens' [a] a
ixi i f
    | i < 0     = \_ -> error "ixi: negative index"
    | otherwise = go i
  where
    go _ []     = error "ixi"
    go 0 (x:xs) = (:xs) <$> f x
    go n (x:xs) = (x:) <$> go (n - 1) xs
{-# INLINE ixi #-}
-- | Turn a traversal and a zero-based position into a lens onto the target
-- at that position.
--
-- The targets are listed out, the chosen one is modified through 'ixi', and
-- the (possibly changed) list is written back by traversing the original
-- structure again and popping one element per target.  Partial: 'ixi' fails
-- when the position is out of range, and the write-back calls 'error' with
-- @\"asList\"@ if the list runs out.
ixt :: forall b a. Traversal' b a -> Int -> Lens' b a
ixt t i f whole = fmap rebuild (ixi i f targets)
  where
    targets = whole ^.. t
    -- Refill the structure from a list of replacement targets.
    rebuild = evalState (traverseOf t (\_ -> state pop) whole)
    pop :: [a] -> (a, [a])
    pop (y:ys) = (y, ys)
    pop []     = error "asList"
{-# INLINE ixt #-}