module Numeric.Backprop.Internal (
BVar
, W
, backpropN, evalBPN
, constVar
, liftOp, liftOp1, liftOp2, liftOp3
, viewVar, setVar, sequenceVar, collectVar, previewVar, toListOfVar
, debugSTN
, debugIR
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Foldable
import Data.IORef
import Data.Kind
import Data.Maybe
import Data.Monoid hiding (Any(..))
import Data.Proxy
import Data.Reflection
import Data.Type.Index
import Data.Type.Product hiding (toList)
import Data.Type.Util
import Data.Type.Vector hiding (itraverse)
import GHC.Exts (Any)
import GHC.Generics
import Lens.Micro
import Numeric.Backprop.Op
import System.IO.Unsafe
import Type.Class.Higher
import Type.Class.Witness
import Unsafe.Coerce
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | A value of type @a@ paired with a 'BRef' recording where it came from.
--
-- Within this module the reference is 'BRC' for values made by 'constVar',
-- @'BRInp' i@ for the inputs numbered by 'fillWengert', and @'BRIx' n@ for
-- results recorded by 'insertNode'.  Both fields are strict.
data BVar s a = BV
    { _bvRef :: !(BRef s)  -- ^ origin of the value
    , _bvVal :: !a         -- ^ the value itself
    }
-- | Origin tag carried by a 'BVar'.  The parameter @s@ is phantom: none of
-- the constructors mention it.
data BRef (s :: Type)
    = BRInp !Int  -- ^ position in the input product (see 'fillWengert')
    | BRIx !Int   -- ^ index handed out by 'insertNode'
    | BRC         -- ^ no slot at all (see 'constVar'); skipped by 'gradRunner'
  deriving (Generic, Show)
-- | No methods are given, so the class default for 'rnf' is used; 'BRef'
-- derives 'Generic', which that default relies on.
instance NFData (BRef s)
-- | Reduces both the reference and the carried value to normal form.
instance NFData a => NFData (BVar s a) where
    rnf (BV ref val) = ref `deepseq` val `deepseq` ()
-- | Extract the value of a 'BVar' whose reference is 'BRC'; any other
-- reference yields 'Nothing'.
bvConst :: BVar s a -> Maybe a
bvConst (BV ref !val) = case ref of
    BRC -> Just val
    _   -> Nothing
-- | Evaluate a 'BVar' so that its reference is in normal form and its value
-- is in weak head normal form.
forceBVar :: BVar s a -> ()
forceBVar (BV ref !_) = ref `deepseq` ()
-- | One input of a tape node: which variable a gradient of type @a@ flows
-- back into, and how it is merged there.  The variable's own type @b@ and
-- its tag @s@ are hidden existentially.
data InpRef :: Type -> Type where
    IR ::
        { -- variable whose slot receives the gradient
          _irIx  :: !(BVar s b)
          -- lens selecting the part of that slot which is updated
        , _irUpd :: !(Lens' b a)
          -- combines the selected part with an incoming gradient; every
          -- construction site in this module passes (+)
        , _irAdd :: !(a -> a -> a)
        } -> InpRef a
-- | Force an 'InpRef': its lens and combining function are evaluated to weak
-- head normal form and its variable is passed through 'forceBVar'.
forceInpRef :: InpRef a -> ()
forceInpRef ref = case ref of
    IR var !_ !_ -> forceBVar var `seq` ()
-- | Render the 'BRef' of the variable an 'InpRef' points at, using the
-- derived 'Show' instance.
debugIR :: InpRef a -> String
debugIR (IR var _ _) = show (_bvRef var)
-- | A recorded operation whose result has type @a@.  The list of input
-- types @as@ is hidden existentially.
data TapeNode :: Type -> Type where
    TN ::
        { -- one 'InpRef' per input of the operation
          _tnInputs :: !(Prod InpRef as)
          -- maps the gradient of the result to one gradient per input
        , _tnGrad   :: !(a -> Tuple as)
        } -> TapeNode a
-- | Force a 'TapeNode': the bang pattern evaluates the gradient function to
-- weak head normal form, and 'forceInpRef' is folded over the inputs.
--
-- NOTE(review): the fold happens in the @()@ monoid, whose @<>@ in @base@
-- ignores both of its arguments, so whether every 'InpRef' is actually
-- forced depends on how 'foldMap1' is implemented for 'Prod' -- confirm.
forceTapeNode :: TapeNode a -> ()
forceTapeNode (TN inps !_) = foldMap1 forceInpRef inps `seq` ()
-- | A 'TapeNode' with its result type hidden.  The 'Num' dictionary for that
-- type is packed alongside; 'initRunner' uses it to produce a zero.
data SomeTapeNode :: Type where
    STN :: forall a. Num a => !(TapeNode a) -> SomeTapeNode
-- | Unwrap a 'SomeTapeNode' and force the node inside with 'forceTapeNode'.
forceSomeTapeNode :: SomeTapeNode -> ()
forceSomeTapeNode stn = case stn of
    STN node -> forceTapeNode node `seq` ()
-- | Show the list obtained by applying 'debugIR' to every input of the
-- wrapped node.
debugSTN :: SomeTapeNode -> String
debugSTN (STN (TN inps _)) = show (foldMap1 (\ir -> [debugIR ir]) inps)
-- | Handle on the tape being recorded: a mutable cell holding the number of
-- nodes inserted so far together with those nodes, most recent first (see
-- 'insertNode').
newtype W = W
    { wRef :: IORef (Int, [SomeTapeNode])
    }
-- | Allocate a fresh, empty tape: zero nodes and an empty node list.
initWengert :: IO W
initWengert = fmap W (newIORef (0, []))
-- | Record a node on the tape and return a variable referring to it.
--
-- The node is consed onto the front of the tape and the counter is bumped
-- in one 'atomicModifyIORef''.  The returned 'BVar' carries the given value
-- and @'BRIx' n@, where @n@ is the counter value /before/ the insertion.
insertNode
    :: Num a
    => TapeNode a
    -> a
    -> W
    -> IO (BVar s a)
insertNode tn !x !w = do
    ix <- atomicModifyIORef' (wRef w) push
    return (BV (BRIx ix) x)
  where
    -- Force the node, the new counter and the new list before they are
    -- written back, so that no thunks pile up inside the reference.
    push (!n, !t) =
        let n' = n + 1
            t' = STN tn : t
        in  forceTapeNode tn `seq` n' `seq` t' `seq` ((n', t'), n)
-- | Wrap a plain value as a 'BVar' tagged 'BRC'.  Nothing is recorded on
-- the tape.
constVar :: a -> BVar s a
constVar x = BV BRC x
-- | Apply an 'Op' to a product of variables, in 'IO'.
--
-- If every input is a constant (per 'bvConst') the op is evaluated with
-- 'evalOp' and the result is itself a constant.  Otherwise the op is run
-- with 'runOpWith', and a node holding the gradient function and one
-- 'InpRef' per input is inserted on the tape reflected from @s@.
liftOp_
    :: forall s as b. (Reifies s W, Num b, Every Num as)
    => Op as b
    -> Prod (BVar s) as
    -> IO (BVar s b)
liftOp_ op !vars =
    case traverse1 (fmap I . bvConst) vars of
      Just consts -> return (constVar (evalOp op consts))
      Nothing     -> insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith op (map1 (I . _bvVal) vars)
    node = TN (imap1 mkRef vars) grad
    -- Each input is updated as a whole ('id') by addition; the 'Num'
    -- instance for that input's type is obtained from its index.
    mkRef :: forall a. Index as a -> BVar s a -> InpRef a
    mkRef ix !var = forceBVar var `seq` (IR var id (+) \\ every @_ @Num ix)
-- | Pure front end of 'liftOp_'.  The tape update is run through
-- 'unsafePerformIO'; the bang pattern evaluates the input product before
-- that action is built.
liftOp
    :: forall s as b. (Reifies s W, Num b, Every Num as)
    => Op as b
    -> Prod (BVar s) as
    -> BVar s b
liftOp o !vs = unsafePerformIO (liftOp_ o vs)
-- | Single-input counterpart of 'liftOp_'.
--
-- A constant input gives a constant result via 'evalOp'.  Any other input
-- is forced, the op is run with 'runOpWith', and a one-input node is
-- inserted on the tape reflected from @s@.
liftOp1_
    :: forall s a b. (Reifies s W, Num a, Num b)
    => Op '[a] b
    -> BVar s a
    -> IO (BVar s b)
liftOp1_ o v = case bvConst v of
    Just x  -> return (constVar (evalOp o (x ::< Ø)))
    Nothing -> forceBVar v `seq` insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< Ø)
    node = TN (IR v id (+) :< Ø) grad
-- | Pure front end of 'liftOp1_', run through 'unsafePerformIO'.  The input
-- variable is evaluated first by the bang pattern.
liftOp1
    :: forall s a b. (Reifies s W, Num a, Num b)
    => Op '[a] b
    -> BVar s a
    -> BVar s b
liftOp1 o !v = unsafePerformIO (liftOp1_ o v)
-- | Two-input counterpart of 'liftOp_'.
--
-- Only when /both/ inputs are constants is the result a constant.
-- Otherwise both inputs are forced and a two-input node is inserted on the
-- tape reflected from @s@.
liftOp2_
    :: forall s a b c. (Reifies s W, Num a, Num b, Num c)
    => Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> IO (BVar s c)
liftOp2_ o v u = case (bvConst v, bvConst u) of
    (Just x, Just y) -> return (constVar (evalOp o (x ::< y ::< Ø)))
    _                ->
        forceBVar v `seq` forceBVar u `seq`
          insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< _bvVal u ::< Ø)
    node = TN (IR v id (+) :< IR u id (+) :< Ø) grad
-- | Pure front end of 'liftOp2_', run through 'unsafePerformIO'.  Both
-- input variables are evaluated first by the bang patterns.
liftOp2
    :: forall s a b c. (Reifies s W, Num a, Num b, Num c)
    => Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> BVar s c
liftOp2 o !v !u = unsafePerformIO (liftOp2_ o v u)
-- | Three-input counterpart of 'liftOp_'.
--
-- Only when all three inputs are constants is the result a constant.
-- Otherwise all inputs are forced and a three-input node is inserted on the
-- tape reflected from @s@.
liftOp3_
    :: forall s a b c d. (Reifies s W, Num a, Num b, Num c, Num d)
    => Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> IO (BVar s d)
liftOp3_ o v u w = case (bvConst v, bvConst u, bvConst w) of
    (Just x, Just y, Just z) ->
        return (constVar (evalOp o (x ::< y ::< z ::< Ø)))
    _ ->
        forceBVar v `seq` forceBVar u `seq` forceBVar w `seq`
          insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (_bvVal v ::< _bvVal u ::< _bvVal w ::< Ø)
    node = TN (IR v id (+) :< IR u id (+) :< IR w id (+) :< Ø) grad
-- | Pure front end of 'liftOp3_', run through 'unsafePerformIO'.  All three
-- input variables are evaluated first by the bang patterns.
liftOp3
    :: forall s a b c d. (Reifies s W, Num a, Num b, Num c, Num d)
    => Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
liftOp3 o !v !u !w = unsafePerformIO (liftOp3_ o v u w)
-- | Read through a lens into a variable, in 'IO'.
--
-- The resulting variable holds the focused part of the input's value.  Its
-- tape node has the input as its only 'InpRef', updated through the same
-- lens, and passes the incoming gradient on unchanged ('only_').
viewVar_
    :: forall a b s. (Reifies s W, Num a)
    => Lens' b a
    -> BVar s b
    -> IO (BVar s a)
viewVar_ l v = forceBVar v `seq` insertNode node focus (reflect (Proxy @s))
  where
    focus = _bvVal v ^. l
    node  = TN (IR v l (+) :< Ø) only_
-- | Pure front end of 'viewVar_', run through 'unsafePerformIO'.  The input
-- variable is evaluated first by the bang pattern.
viewVar
    :: forall a b s. (Reifies s W, Num a)
    => Lens' b a
    -> BVar s b
    -> BVar s a
viewVar l !v = unsafePerformIO (viewVar_ l v)
-- | Write one variable into part of another through a lens, in 'IO'.
--
-- The result holds the second variable's value with the lens target
-- replaced by the first variable's value.  For the backward pass, the lens
-- is run over the incoming gradient with the pairing functor: the focused
-- part becomes the gradient of the written variable, and the same gradient
-- with that part set to @0@ becomes the gradient of the container.
setVar_
    :: forall a b s. (Reifies s W, Num a, Num b)
    => Lens' b a
    -> BVar s a
    -> BVar s b
    -> IO (BVar s b)
setVar_ l w v =
    forceBVar v `seq` forceBVar w `seq`
      insertNode node updated (reflect (Proxy @s))
  where
    updated = (l .~ _bvVal w) (_bvVal v)
    node    = TN (IR w id (+) :< IR v id (+) :< Ø) splitGrad
    splitGrad d = dw ::< dv ::< Ø
      where
        -- lazy pattern on purpose: the pair is only taken apart on demand
        (dw, dv) = l (,0) d
-- | Pure front end of 'setVar_', run through 'unsafePerformIO'.  Both
-- variables are evaluated first by the bang patterns.
setVar
    :: forall a b s. (Reifies s W, Num a, Num b)
    => Lens' b a
    -> BVar s a
    -> BVar s b
    -> BVar s b
setVar l !w !v = unsafePerformIO (setVar_ l w v)
-- | Turn a variable holding a traversable container into a container of
-- variables, one per element.  Implemented as 'traverseVar'' with the
-- container itself as the shape ('id') and 'traverse' as the traversal,
-- run through 'unsafePerformIO'.
sequenceVar
    :: forall t a s. (Reifies s W, Traversable t, Num a)
    => BVar s (t a)
    -> t (BVar s a)
sequenceVar !v = unsafePerformIO (traverseVar' id traverse v)
-- | Gather a container of variables into one variable holding a container
-- of their values, in 'IO'.
--
-- The recorded node has one 'InpRef' per element.  Its gradient function
-- flattens the incoming container with 'toList' and fits the resulting list
-- to the number of inputs with 'listToVecDef', using @0@ as the default.
collectVar_
    :: forall a t s. (Reifies s W, Foldable t, Functor t, Num (t a), Num a)
    => t (BVar s a)
    -> IO (BVar s (t a))
collectVar_ !vs = withV (toList vs) $ \(vVec :: Vec n (BVar s a)) -> do
    let node :: TapeNode (t a)
        node = TN
          (vecToProd (vmap (\iv -> IR (getI iv) id (+)) vVec))
          (\d -> vecToProd (listToVecDef 0 (vecLen vVec) (map I (toList d))))
    -- force every element before the node goes onto the tape
    for_ vs (evaluate . forceBVar)
    insertNode node (fmap _bvVal vs) (reflect (Proxy @s))
-- | Pure front end of 'collectVar_', run through 'unsafePerformIO'.  The
-- container is evaluated to weak head normal form first by the bang pattern.
collectVar
    :: forall a t s. (Reifies s W, Foldable t, Functor t, Num (t a), Num a)
    => t (BVar s a)
    -> BVar s (t a)
collectVar !vs = unsafePerformIO (collectVar_ vs)
-- | Shared worker for 'sequenceVar', 'previewVar' and 'toListOfVar'.
--
-- The first argument extracts a container of targets from the variable's
-- value.  Each target, numbered from 0 in traversal order by 'itraverse',
-- becomes its own variable.  Its node points back at the source variable
-- through @'ixt' t i@ and passes the gradient on unchanged ('only_').
traverseVar'
    :: forall b a f s. (Num a, Reifies s W, Traversable f)
    => (b -> f a)
    -> Traversal' b a
    -> BVar s b
    -> IO (f (BVar s a))
traverseVar' f t v = forceBVar v `seq` itraverse record (f (_bvVal v))
  where
    record :: Int -> a -> IO (BVar s a)
    record i y =
        let node = TN (IR v (ixt t i) (+) :< Ø) only_
        in  insertNode node y (reflect (Proxy @s))
-- | Variable for the first target of a traversal, or 'Nothing' when the
-- traversal has no targets.  Runs 'traverseVar'' through 'unsafePerformIO'.
previewVar
    :: forall b a s. (Num a, Reifies s W)
    => Traversal' b a
    -> BVar s b
    -> Maybe (BVar s a)
previewVar t !v =
    unsafePerformIO (traverseVar' (\b -> listToMaybe (toListOf t b)) t v)
-- | One variable per target of a traversal, in traversal order.  Runs
-- 'traverseVar'' through 'unsafePerformIO'.
toListOfVar
    :: forall b a s. (Num a, Reifies s W)
    => Traversal' b a
    -> BVar s b
    -> [BVar s a]
toListOfVar t !v = unsafePerformIO (traverseVar' (\b -> toListOf t b) t v)
-- | Mutable state of one backward pass.  Slots hold values of differing
-- types, so they are stored as 'Any' and converted with 'unsafeCoerce' at
-- the point of use.
data Runner s = R
    { _rDelta  :: !(MV.MVector s Any)  -- ^ one gradient slot per tape node
    , _rInputs :: !(MV.MVector s Any)  -- ^ one gradient slot per input
    }
-- | Allocate the gradient slots for a backward pass and zero them.
--
-- The first argument is the tape as read from 'W': the node count @n@ and
-- the nodes, newest first.  'insertNode' numbers nodes upwards from 0 while
-- consing them onto the front, so the head of the list is node @n - 1@;
-- the nodes are therefore paired with the descending indices
-- @n - 1, n - 2, ..@.  The second argument is the number of inputs with a
-- 'Num' witness for each one.
--
-- Every slot is set to @0@ at the type of the corresponding node or input
-- and stored as 'Any' via 'unsafeCoerce'.
initRunner
    :: (PrimMonad m, PrimState m ~ s)
    => (Int, [SomeTapeNode])
    -> (Int, [Some (Wit1 Num)])
    -> m (Runner s)
initRunner (n, stns) (nx, xs) = do
    delts <- MV.new n
    -- the pattern signature brings the node's hidden result type into scope
    for_ (zip [n - 1, n - 2 ..] stns) $ \(i, STN (TN{..} :: TapeNode c)) ->
      MV.write delts i $ unsafeCoerce @c 0
    inps <- MV.new nx
    for_ (zip [0 ..] xs) $ \(i, Some (Wit1 :: Wit1 Num c)) ->
      MV.write inps i $ unsafeCoerce @c 0
    return $ R delts inps
-- | Run the backward pass over a recorded tape.
--
-- The proxy fixes the type @b@ of the final result.  If the tape is not
-- empty, the slot of the most recently recorded node (index @n - 1@) is
-- seeded with @1 :: b@.  The nodes are then visited newest first, paired
-- with the descending indices @n - 1, n - 2, ..@ that 'insertNode' gave
-- them.  For each node its gradient slot is read and passed to '_tnGrad',
-- and each resulting input gradient is merged into the slot its 'InpRef'
-- points at.
gradRunner
    :: forall m b s p. (PrimMonad m, PrimState m ~ s, Num b)
    => p b
    -> Runner s
    -> (Int, [SomeTapeNode])
    -> m ()
gradRunner _ R{..} (n, stns) = do
    when (n > 0) $
      MV.write _rDelta (n - 1) (unsafeCoerce @b 1)
    zipWithM_ go [n - 1, n - 2 ..] stns
  where
    go :: Int -> SomeTapeNode -> m ()
    go i (STN TN{..}) = do
      delt <- MV.read _rDelta i
      -- the slot was stored as 'Any'; coerce it back to the node's type
      let gs = _tnGrad (unsafeCoerce delt)
      zipWithPM_ propagate _tnInputs gs
    -- Merge one input gradient into its target slot: through the lens, the
    -- old contents are combined with the new gradient on the right.
    propagate :: forall x. InpRef x -> I x -> m ()
    propagate (IR v ln (+*)) (I d) = case _bvRef v of
      BRInp i -> flip (MV.modify _rInputs) i $
        unsafeCoerce . (ln %~ (+* d)) . unsafeCoerce
      BRIx i  -> flip (MV.modify _rDelta) i $
        unsafeCoerce . (ln %~ (+* d)) . unsafeCoerce
      -- constants have no slot; their gradient is dropped
      BRC     -> return ()
-- | Run a function on 'BVar's over the given inputs, returning its result
-- together with one gradient per input.
--
-- The tape is recorded by 'fillWengert' under 'unsafePerformIO'.  The
-- gradients come from a separate 'runST' computation that builds a 'Runner'
-- with 'initRunner', runs 'gradRunner', and reads the input slots back into
-- a product shaped like the inputs.  If 'fillProd' cannot fit the slots to
-- that shape the result is @error "backpropN"@.
backpropN
    :: forall as b. (Every Num as, Num b)
    => (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> (b, Tuple as)
backpropN f !xs = (out, grads)
  where
    -- strict all the way down: the tape, its two components and the result
    -- are evaluated as soon as this binding group is entered
    !(!tape@(!_,!_),!out) = unsafePerformIO $ fillWengert f xs
    grads :: Tuple as
    grads = runST $ do
        runner <- initRunner tape (first getSum (ifoldMap1 witness xs))
        gradRunner (Proxy @b) runner tape
        frozen <- V.freeze (_rInputs runner)
        return
          (fromMaybe (error "backpropN")
            (fillProd (\_ d -> I (unsafeCoerce d)) xs (toList frozen)))
      where
        -- count the inputs and collect a 'Num' witness for each one
        witness :: forall a. Index as a -> I a -> (Sum Int, [Some (Wit1 Num)])
        witness i (I _) = (1, [Some (Wit1 :: Wit1 Num a)]) \\ every @_ @Num i
-- | Run a function on 'BVar's and return only its result.  The tape that
-- 'fillWengert' records along the way is discarded.
evalBPN
    :: forall as b. ()
    => (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> b
evalBPN f xs = snd (unsafePerformIO (fillWengert f xs))
-- | Record the tape of a function applied to the given inputs.
--
-- A fresh 'W' is created and reified as @s@.  The function is applied to
-- the inputs wrapped as 'BVar's tagged @'BRInp' 0@, @'BRInp' 1@, .. in
-- order, and its result is forced with 'forceBVar'.  The tape is then read
-- back, every node on it is forced, and the tape is returned together with
-- the result's value.
fillWengert
    :: forall as b. ()
    => (forall s. Reifies s W => Prod (BVar s) as -> BVar s b)
    -> Tuple as
    -> IO ((Int, [SomeTapeNode]), b)
fillWengert f xs = do
    w <- initWengert
    out <- reify w $ \(Proxy :: Proxy s) -> do
        let result = f (inpProd @s)
        -- forcing the result is what makes the recording actually happen
        evaluate (forceBVar result)
        return (_bvVal result)
    tape <- readIORef (wRef w)
    for_ (snd tape) (evaluate . forceSomeTapeNode)
    return (tape, out)
  where
    inpProd :: forall s. Prod (BVar s) as
    inpProd = evalState (traverse1 (\ix -> state (number (getI ix))) xs) 0
      where
        -- tag an input with the next free position and advance the counter
        number :: a -> Int -> (BVar s a, Int)
        number x i = (BV (BRInp i) x, i + 1)
-- | Arithmetic on 'BVar's.  Each operator lifts the corresponding 'Op'
-- with 'liftOp1' or 'liftOp2'.  Integer literals become constants via
-- 'constVar'.
instance (Num a, Reifies s W) => Num (BVar s a) where
    (+)         = liftOp2 (+.)
    -- subtraction must be bound to the operator @(-)@; binding the unit
    -- pattern @()@ instead leaves the method undefined
    (-)         = liftOp2 (-.)
    (*)         = liftOp2 (*.)
    negate      = liftOp1 negateOp
    signum      = liftOp1 signumOp
    abs         = liftOp1 absOp
    fromInteger = constVar . fromInteger
-- | Division and reciprocal are lifted from their 'Op's.  Rational literals
-- become constants via 'constVar'.
instance (Fractional a, Reifies s W) => Fractional (BVar s a) where
    fromRational r = constVar (fromRational r)
    recip          = liftOp1 recipOp
    (/)            = liftOp2 (/.)
-- | Every method lifts the correspondingly named 'Op', except 'pi', which
-- is a constant.
instance (Floating a, Reifies s W) => Floating (BVar s a) where
    -- constant
    pi      = constVar pi

    -- exponentials, logarithms and powers
    exp     = liftOp1 expOp
    log     = liftOp1 logOp
    sqrt    = liftOp1 sqrtOp
    (**)    = liftOp2 (**.)
    logBase = liftOp2 logBaseOp

    -- trigonometric functions and their inverses
    sin     = liftOp1 sinOp
    cos     = liftOp1 cosOp
    tan     = liftOp1 tanOp
    asin    = liftOp1 asinOp
    acos    = liftOp1 acosOp
    atan    = liftOp1 atanOp

    -- hyperbolic functions and their inverses
    sinh    = liftOp1 sinhOp
    cosh    = liftOp1 coshOp
    tanh    = liftOp1 tanhOp
    asinh   = liftOp1 asinhOp
    acosh   = liftOp1 acoshOp
    atanh   = liftOp1 atanhOp
-- | 'traverse' with a position: the callback also receives the 0-based
-- index of each element, counted in traversal order.  The counter is
-- threaded through a 'StateT' layer over the caller's monad.
itraverse
    :: forall t a b f. (Traversable t, Monad f)
    => (Int -> a -> f b) -> t a -> f (t b)
itraverse f xs = evalStateT (traverse step xs) 0
  where
    -- run the callback at the current index, then advance the index
    step :: a -> StateT Int f b
    step x = StateT $ \i -> fmap (\y -> (y, i + 1)) (f i x)
-- | Lens onto the element at a 0-based position in a list.
--
-- Partial: calls @error "ixi"@ when the list runs out before the position
-- is reached.  A negative position never reaches 0, so it also ends in that
-- error once the list is exhausted.
ixi :: Int -> Lens' [a] a
ixi _ _ []     = error "ixi"
ixi 0 f (x:xs) = (:xs) <$> f x
-- step past the head and look for position @n - 1@ in the tail
ixi n f (x:xs) = (x:) <$> ixi (n - 1) f xs
-- | Lens onto the @i@-th target (0-based) of a traversal.
--
-- All targets are first collected into a list and 'ixi' focuses the
-- requested one.  The structure is then rebuilt by traversing it again and
-- replacing each target with the next element of the (possibly modified)
-- list.  Rebuilding calls @error "asList"@ if that list is too short.
ixt :: forall b a. Traversal' b a -> Int -> Lens' b a
ixt t i f xs = fmap rebuild (ixi i f targets)
  where
    targets = toListOf t xs
    rebuild = evalState (traverseOf t (\_ -> state pop) xs)
    -- hand out the next replacement and keep the rest as the new state
    pop :: [a] -> (a, [a])
    pop []       = error "asList"
    pop (y : ys) = (y, ys)