{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Torch.Indef.Static.NN.Backprop where
import Numeric.Backprop
import Numeric.Dimensions
import Data.Singletons.Prelude.List (SplitAt, Product)
import qualified Data.Singletons.Prelude.List as S
import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import Torch.Indef.Static.Tensor.Math
import Torch.Indef.Static.Tensor.Math.Pointwise.Signed ()
import qualified Torch.Indef.Index as Ix
import qualified Torch.Indef.Static.Tensor as T
import System.IO.Unsafe
import qualified Torch.Indef.Dynamic.Tensor.Math as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pointwise as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pairwise as Dynamic
import Debug.Trace
-- | 'Backprop' instance for statically-shaped tensors.
--
-- Gradients are accumulated with the tensor's '(+)'. 'zero' and 'one'
-- ignore the tensor they are given and return @constant 0@ and
-- @constant 1@ respectively at the same static shape @d@.
instance Dimensions d => Backprop (Tensor d) where
  add = (+)
  zero _ = constant 0
  one _ = constant 1
-- | Backprop-aware 'T.unsqueeze1d'.
--
-- Forward pass: inserts a singleton axis at position @n@, turning a
-- @Tensor d@ (with @d ~ rs ++ ls@) into @Tensor (rs ++ '[1] ++ ls)@.
--
-- Backward pass: the incoming gradient has the unsqueezed shape, and
-- 'T.squeeze1d' at the same position removes the singleton axis again
-- to yield a gradient of shape @d@.
unsqueeze1dBP
  :: forall s d rs ls n
  .  Reifies s W
  => All Dimensions '[d, (rs ++ '[1] ++ ls)]
  => '( rs, ls) ~ (SplitAt n d)
  => '( rs, 1:+ls) ~ (SplitAt n (rs ++ '[1] ++ ls))
  => (rs ++ ls) ~ d
  => Dim n
  -> BVar s (Tensor d)
  -> BVar s (Tensor (rs ++ '[1] ++ ls))
unsqueeze1dBP d = liftOp1 (op1 forward)
  where
    -- Pair the forward result with its (input-independent) gradient map.
    forward
      :: Tensor d
      -> (Tensor (rs ++ '[1] ++ ls), Tensor (rs ++ '[1] ++ ls) -> Tensor d)
    forward t = (T.unsqueeze1d d t, backward)

    -- Undo the unsqueeze on the output gradient.
    backward :: Tensor (rs ++ '[1] ++ ls) -> Tensor d
    backward = T.squeeze1d d
-- | Backprop-aware 'T.squeeze1d'.
--
-- Forward pass: removes the axis at position @n@. The
-- @'(rs, 1:+ls) ~ SplitAt n d@ constraint requires that axis to have
-- size 1. This turns a @Tensor d@ into @Tensor (rs ++ ls)@.
--
-- Backward pass: the incoming gradient has the squeezed shape, and
-- 'T.unsqueeze1d' at position @n@ re-inserts the singleton axis to
-- yield a gradient of shape @d@.
squeeze1dBP
  :: forall s d rs ls n
  .  Reifies s W
  => All Dimensions '[d, rs ++ ls]
  => All KnownDim '[n]
  => '(rs, 1:+ls) ~ (SplitAt n d)
  => d ~ (S.Take n (rs ++ ls) ++ '[1] ++ S.Drop n (rs ++ ls))
  => Dim n
  -> BVar s (Tensor d)
  -> BVar s (Tensor (rs ++ ls))
squeeze1dBP d = liftOp1 (op1 forward)
  where
    -- Pair the forward result with its (input-independent) gradient map.
    forward :: Tensor d -> (Tensor (rs ++ ls), Tensor (rs ++ ls) -> Tensor d)
    forward t = (T.squeeze1d d t, backward)

    -- Re-insert the squeezed axis into the output gradient. The position
    -- comes from the type-level @n@ (hence the KnownDim constraint).
    backward :: Tensor (rs ++ ls) -> Tensor d
    backward = T.unsqueeze1d (dim :: Dim n)
-- | Flatten all non-batch dimensions of a @bs:+d@ tensor into a single
-- axis, producing a @[bs, Product d]@ matrix.
--
-- Also returns the matching backward function. It reshapes a gradient
-- of the flattened shape back to @bs:+d@.
--
-- Both directions are plain 'resizeAs' calls. The
-- @Product (bs:+d) ~ Product '[bs, Product d]@ constraint guarantees the
-- element counts agree. The 'IO' wrapper adds no effects of its own
-- here: both results are lifted with 'pure'.
flattenBatchIO
  :: forall d bs . (All KnownDim '[Product d, bs], All Dimensions '[bs:+d, d])
  => Product (bs:+d) ~ Product '[bs, Product d]
  => Tensor (bs:+d)
  -> IO (Tensor '[bs, Product d], Tensor '[bs, Product d] -> IO (Tensor (bs:+d)))
flattenBatchIO i = pure (resizeAs i, pure . resizeAs)