-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Static.NN.Backprop
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Backprop helper instances for static tensors, as well as any helper
-- functions that might work well with backprop.
-------------------------------------------------------------------------------
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.Indef.Static.NN.Backprop where

-- NOTE(review): 'Debug.Trace' and 'Ix' are not referenced anywhere in this
-- module, and 'System.IO.Unsafe' plus the 'Dynamic' imports are only used by
-- the commented-out in-place instance methods below. They are kept so that
-- code can be re-enabled; consider pruning them if it never will be.
import Debug.Trace
import System.IO.Unsafe

import Data.Singletons.Prelude.List (SplitAt, Product)
import qualified Data.Singletons.Prelude.List as S -- hiding (All, Drop, Take, type (++))
import Numeric.Backprop
import Numeric.Dimensions

import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import Torch.Indef.Static.Tensor.Math
import Torch.Indef.Static.Tensor.Math.Pointwise.Signed ()
import qualified Torch.Indef.Index as Ix
import qualified Torch.Indef.Static.Tensor as T
import qualified Torch.Indef.Dynamic.Tensor.Math as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pointwise as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pairwise as Dynamic

-- | Static tensors are 'Backprop' values:
--
-- * 'add' is the tensor's '+';
-- * 'zero' and 'one' ignore their argument and return @constant 0@ and
--   @constant 1@ respectively. The shape is fixed by the type index @d@, so
--   nothing needs to be read from the argument.
--
-- NOTE(review): this is an orphan instance. Neither 'Backprop' nor 'Tensor'
-- is defined in this module.
instance Dimensions d => Backprop (Tensor d) where
  add = (+)
  zero _ = constant 0
  one _ = constant 1

-- Disabled alternatives for the instance methods above. Each one goes through
-- the trailing-underscore 'Dynamic' operations under 'unsafePerformIO' and
-- returns one of its inputs:
--
-- zero a = unsafePerformIO $ Dynamic.zero_ (asDynamic a) >> pure a
-- {-# NOINLINE zero #-}
-- one a = unsafePerformIO $ Dynamic.onesLike_ (asDynamic a) (asDynamic a) >> pure a
-- {-# NOINLINE one #-}
-- add a b = unsafePerformIO $ Dynamic.cadd_ (asDynamic b) 1 (asDynamic a) >> pure b
-- {-# NOINLINE add #-}

-- | Backprop-aware 'T.unsqueeze1d'. Inserts a dimension of size 1 at
-- position @n@, turning a tensor of shape @rs ++ ls@ into one of shape
-- @rs ++ [1] ++ ls@.
--
-- The gradient is sent back through 'T.squeeze1d' along the same dimension.
unsqueeze1dBP
  :: forall s d rs ls n
  .  Reifies s W
  => All Dimensions '[d, (rs ++ '[1] ++ ls)]
  => '(rs, ls) ~ (SplitAt n d)
  => '(rs, 1:+ls) ~ (SplitAt n (rs ++ '[1] ++ ls))
  => (rs ++ ls) ~ d
  => Dim n                               -- ^ position of the new size-1 dimension
  -> BVar s (Tensor d)
  -> BVar s (Tensor (rs ++ '[1] ++ ls))
unsqueeze1dBP n = liftOp1 . op1 $ \t -> (T.unsqueeze1d n t, backward)
  where
    -- Undo the unsqueeze on the incoming gradient.
    backward :: Tensor (rs ++ '[1] ++ ls) -> Tensor d
    backward = T.squeeze1d n

-- | Backprop-aware 'T.squeeze1d'. Removes the size-1 dimension at position
-- @n@, turning a tensor of shape @rs ++ [1] ++ ls@ into one of shape
-- @rs ++ ls@.
--
-- The gradient is sent back through 'T.unsqueeze1d' along the same dimension.
squeeze1dBP
  :: forall s d rs ls n
  .  Reifies s W
  => All Dimensions '[d, rs ++ ls]
  => All KnownDim '[n]
  => '(rs, 1:+ls) ~ (SplitAt n d)
  => d ~ (S.Take n (rs ++ ls) ++ '[1] ++ S.Drop n (rs ++ ls))
  => Dim n                               -- ^ position of the size-1 dimension to drop
  -> BVar s (Tensor d)
  -> BVar s (Tensor (rs ++ ls))
squeeze1dBP n = liftOp1 . op1 $ \t -> (T.squeeze1d n t, backward)
  where
    -- Re-insert the squeezed dimension into the incoming gradient. This reuses
    -- the caller's 'Dim n', as 'unsqueeze1dBP' does, rather than building a
    -- second one from 'KnownDim'. Both values have the same type 'Dim n'.
    backward :: Tensor (rs ++ ls) -> Tensor d
    backward = T.unsqueeze1d n

-- | A backprop-style 'flatten' with a batch dimension, in 'IO'.
--
-- Uses 'resizeAs' to reshape the input from @bs :+ d@ to
-- @[bs, Product d]@. Returns the result together with a backward function
-- that reshapes an incoming gradient back to @bs :+ d@, also via 'resizeAs'.
--
-- NOTE(review): both directions only wrap a 'resizeAs' result in 'pure'. The
-- 'IO' wrapper is presumably there to match other IO-based layers; confirm
-- before changing the interface.
flattenBatchIO
  :: forall d bs
  .  (All KnownDim '[Product d, bs], All Dimensions '[bs:+d, d])
  => Product (bs:+d) ~ Product '[bs, Product d]
  => Tensor (bs:+d)
  -> IO (Tensor '[bs, Product d], Tensor '[bs, Product d] -> IO (Tensor (bs:+d)))
flattenBatchIO i = pure (resizeAs i, pure . resizeAs)

-- Disabled: clamp a single-element tensor to @[mn, mx]@. The gradient is
-- passed through unchanged via 'id'.
--
-- clip :: Reifies s W => (HsReal, HsReal) -> BVar s (Tensor '[1]) -> BVar s (Tensor '[1])
-- clip (mn,mx) = liftOp1 . op1 $ \i ->
--   let
--     x = case get1d i 0 of
--           x | x > mx -> mx
--             | x < mn -> mn
--             | otherwise -> x
--   in
--     (scalar x, id)