{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.NN.Linear where
import Data.List
import GHC.Generics
import Data.Singletons.Prelude.List hiding (All)
import Numeric.Backprop
import Numeric.Dimensions
import System.IO.Unsafe
import Debug.Trace
import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import Torch.Indef.Static.Tensor.Math
import Torch.Indef.Static.Tensor.Math.Blas
import Torch.Indef.Static.Tensor.Math.Pointwise
import Torch.Indef.Static.Tensor.Math.Pointwise.Signed ()
import Torch.Indef.Static.Tensor.Math.Pairwise (Pairwise(..))
import Torch.Indef.Static.NN.Backprop ()
import qualified Torch.Indef.Dynamic.NN as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pointwise as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Math.Pairwise as Dynamic
-- | A fully-connected (affine) layer from @i@ input features to @o@ output
-- features.
--
-- The wrapped pair is @(weights, bias)@: a weight matrix of shape @[i, o]@
-- and a bias vector of shape @[o]@.
newtype Linear i o = Linear
  { getTensors :: (Tensor '[i, o], Tensor '[o])
  }
  deriving (Eq, Generic)
-- | Renders the layer's type-level sizes as
-- @Linear (input: N, output: M)@; tensor contents are not shown.
instance (KnownDim i, KnownDim o) => Show (Linear i o) where
  show layer = concat
    [ "Linear (", "input: ", show (inputSize layer)
    , ",", " output: ", show (outputSize layer), ")"
    ]
-- | Lets 'Linear' be used as a differentiable parameter with @backprop@.
--
-- * 'zero' and 'one' ignore their argument and return layers whose weight
--   and bias tensors are filled with 0 and 1 respectively.
-- * 'add' sums weights and biases element-wise and returns a fresh layer.
--
-- 'add' used to update its second argument's tensors in place (via
-- 'Dynamic.cadd_' under 'unsafePerformIO'). Any other reference to that
-- gradient would then have seen it change, so 'add' is now pure.
instance (KnownDim i, KnownDim o) => Backprop (Linear i o) where
  zero = const . Linear $ (constant 0, constant 0)
  one  = const . Linear $ (constant 1, constant 1)
  add (Linear (a0, b0)) (Linear (a1, b1)) = Linear (a1 + a0, b1 + b0)
-- | Element-wise arithmetic on layers. Every operation is applied separately
-- to the weight tensor and to the bias tensor.
--
-- 'signum' was previously missing. 'Num' has no default for it, so any call
-- failed at runtime with a missing-method error.
instance (KnownDim i, KnownDim o) => Num (Linear i o) where
  (+) (Linear (a0, b0)) (Linear (a1, b1)) = Linear (a0 + a1, b0 + b1)
  (-) (Linear (a0, b0)) (Linear (a1, b1)) = Linear (a0 - a1, b0 - b1)
  (*) (Linear (a0, b0)) (Linear (a1, b1)) = Linear (a0 * a1, b0 * b1)
  abs    (Linear (a0, b0)) = Linear (abs a0, abs b0)
  signum (Linear (a0, b0)) = Linear (signum a0, signum b0)
  fromInteger n = Linear (fromInteger n, fromInteger n)
-- | Scalar arithmetic on a layer, delegated to the 'Pairwise' instance of the
-- underlying @(weights, bias)@ pair.
instance (KnownDim i, KnownDim o) => Pairwise (Linear i o) HsReal where
  lyr ^+ v = Linear (getTensors lyr ^+ v)
  lyr ^- v = Linear (getTensors lyr ^- v)
  lyr ^* v = Linear (getTensors lyr ^* v)
  lyr ^/ v = Linear (getTensors lyr ^/ v)
-- | In-place parameter update: adds @lr@ times each gradient tensor into the
-- corresponding layer tensor, weights first and then bias, via
-- 'Dynamic.cadd_'.
--
-- NOTE(review): nothing is negated here. Presumably callers pass a negative
-- @lr@ for gradient descent; confirm.
update_
  :: (KnownDim i, KnownDim o)
  => Linear i o   -- ^ layer whose tensors are updated
  -> HsReal       -- ^ scale applied to the gradient
  -> Linear i o   -- ^ gradients, same shapes as the layer
  -> IO ()
update_ (Linear (w, b)) lr (Linear (gw, gb)) =
  mapM_ (\(param, grad) -> Dynamic.cadd_ param lr grad)
    [ (asDynamic w, asDynamic gw)
    , (asDynamic b, asDynamic gb)
    ]
-- | Pure counterpart of 'update_': returns @layer + grads * lr@ as a new layer,
-- using the 'Num' and 'Pairwise' instances above.
update
  :: (KnownDim i, KnownDim o)
  => Linear i o   -- ^ current layer
  -> HsReal       -- ^ scale applied to the gradient
  -> Linear i o   -- ^ gradients
  -> Linear i o
update layer lr grads = layer + scaled
  where
    scaled = grads ^* lr
-- | The layer's weight matrix, shape @[i, o]@.
weights :: Linear i o -> Tensor '[i, o]
weights = fst . getTensors
-- | The layer's bias vector, shape @[o]@.
bias :: Linear i o -> Tensor '[o]
bias = snd . getTensors
-- | Number of input features, read from the type-level dimension @i@.
-- The layer argument is never inspected.
inputSize :: forall i o . KnownDim i => Linear i o -> Int
inputSize = const (fromIntegral (dimVal (dim :: Dim i)))
-- | Number of output features, read from the type-level dimension @o@.
-- The layer argument is never inspected.
--
-- The stray type variables @kW dW@ were dropped from the @forall@. They did
-- not appear anywhere in the type and were likely copied from a convolution
-- layer.
outputSize :: forall i o . KnownDim o => Linear i o -> Int
outputSize _ = fromIntegral (dimVal (dim :: Dim o))
-- | Builds a layer by running @initer@ twice: first for the weight matrix,
-- then for the bias vector.
mkLinear
  :: (KnownDim i, KnownDim o)
  => (forall d . Dimensions d => IO (Tensor d))  -- ^ tensor initialiser
  -> IO (Linear i o)
mkLinear initer = do
  w <- initer
  b <- initer
  pure (Linear (w, b))
-- | Affine layer applied to a single (unbatched) input vector, as a
-- @backprop@ operation.
--
-- Forward: @out = transpose w · x + b@, with @w :: Tensor '[i, o]@ and
-- @b :: Tensor '[o]@.
--
-- Backward, given the output gradient @gout :: Tensor '[o]@:
--
-- * weight gradient: the outer product @x ⊗ gout@;
-- * bias gradient: @gout@;
-- * input gradient: @w · gout@.
--
-- Bug fix: the bias gradient used to be @cadd b lr gout@, i.e. @b + gout@,
-- which leaked the current bias value into the gradient. It is now built on
-- a zero tensor, like the weight gradient and the bias gradient in
-- 'linearBatchWithIO'.
linear
  :: forall s i o
  .  Reifies s W
  => All KnownDim '[i,o]
  => BVar s (Linear i o)
  -> BVar s (Tensor '[i])
  -> BVar s (Tensor '[o])
linear = liftOp2 $ op2 $ \l x ->
  ( updateOutput x l
  , \gout -> (accGradParameters x gout, updateGradInput gout (weights l))
  )
  where
    updateOutput :: Tensor '[i] -> Linear i o -> Tensor '[o]
    updateOutput x (Linear (w, b)) = addmv 1 b 1 (transpose2d w) x

    -- w · gout; the 0-scaled zero tensor contributes nothing.
    updateGradInput :: Tensor '[o] -> Tensor '[i, o] -> Tensor '[i]
    updateGradInput gout w = addmv 0 (constant 0) 1 w gout

    accGradParameters :: Tensor '[i] -> Tensor '[o] -> Linear i o
    accGradParameters x gout = Linear (gw, gb)
      where
        lr = 1
        gw = addr 1 (constant 0) lr x gout
        gb = cadd (constant 0) lr gout
-- | Batched affine layer as a @backprop@ operation. It wraps 'linearBatchIO'.
--
-- Both the forward action and the deferred gradient action are run with
-- 'unsafePerformIO'. This is only sound while they perform no observable side
-- effects. NOTE(review): 'linearBatchWithIO' currently just returns pure
-- values; re-check this if it starts writing into the buffers it receives.
linearBatch
  :: forall s i o b
  .  Reifies s W
  => All KnownDim '[b,i,o]
  => BVar s (Linear i o)
  -> BVar s (Tensor '[b, i])
  -> BVar s (Tensor '[b, o])
linearBatch = liftOp2 . op2 $ \l x ->
  unsafePerformIO $ withPureGrad <$> linearBatchIO l x
  where
    withPureGrad (out, getgrad) = (out, unsafePerformIO . getgrad)
-- | 'linearBatchWithIO' with freshly created ('new') buffers for the output,
-- the input gradient and the parameter gradients.
linearBatchIO
  :: forall i o b
  .  All KnownDim '[b,i,o]
  => (Linear i o)
  -> (Tensor '[b, i])
  -> IO (Tensor '[b, o], Tensor '[b, o] -> IO ((Linear i o), (Tensor '[b, i])))
linearBatchIO layer input =
  linearBatchWithIO (Just outBuf) (Just gradInBuf) (Just gradParamBuf) layer input
  where
    outBuf :: Tensor '[b, o]
    outBuf = new

    gradInBuf :: Tensor '[b, i]
    gradInBuf = new

    gradParamBuf :: Linear i o
    gradParamBuf = Linear (new, new)
-- | Batched forward pass, plus a deferred backward pass.
--
-- For an input @x :: Tensor '[b, i]@ the output is @x · w + 1 ⊗ bias@
-- (shape @[b, o]@). The returned function maps an output gradient
-- @gout :: Tensor '[b, o]@ to:
--
-- * the parameter gradients @(xᵀ · gout, goutᵀ · 1)@, where the bias
--   gradient sums @gout@ over the batch;
-- * the input gradient @gout · wᵀ@.
--
-- NOTE(review): the three 'Maybe' buffer arguments are currently ignored.
-- Results are always freshly allocated. The arguments are kept so callers'
-- signatures do not change.
linearBatchWithIO
  :: forall i o b
  .  All KnownDim '[b,i,o]
  => Maybe (Tensor '[b, o])
  -> Maybe (Tensor '[b, i])
  -> Maybe (Linear i o)
  -> (Linear i o)
  -> (Tensor '[b, i])
  -> IO (Tensor '[b, o], Tensor '[b, o] -> IO ((Linear i o), (Tensor '[b, i])))
linearBatchWithIO _moutbuffer _mgradinbuf _mgradparambuf l x =
  pure (updateOutput x l, backward)
  where
    lr = 1

    backward :: Tensor '[b, o] -> IO (Linear i o, Tensor '[b, i])
    backward gout =
      pure (accGradParameters x gout, updateGradInput gout (weights l))

    -- x · w, then add the bias to every row via the outer product 1_b ⊗ bias.
    updateOutput :: Tensor '[b, i] -> Linear i o -> Tensor '[b, o]
    updateOutput inp (Linear (w, bs)) =
      let
        o = addmm 1 (constant 0) 1 inp w
      in
        addr 1 o 1 (constant 1) bs

    -- gout · wᵀ; the 0-scaled zero tensor contributes nothing.
    updateGradInput :: Tensor '[b, o] -> Tensor '[i, o] -> Tensor '[b, i]
    updateGradInput gout w = addmm 0 (constant 0) 1 gout (transpose2d w)

    accGradParameters :: Tensor '[b, i] -> Tensor '[b, o] -> Linear i o
    accGradParameters inp gout = Linear (gw, gb)
      where
        gw :: Tensor '[i, o]
        gw = addmm 1 (constant 0) lr (transpose2d inp) gout
        -- goutᵀ · ones: sums the output gradient over the batch dimension.
        gb :: Tensor '[o]
        gb = addmv 1 (constant 0) lr tgout (constant 1)
        tgout :: Tensor '[o, b]
        tgout = transpose2d gout