{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module Torch.Typed.NN.Linear where

import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Specification of a 'Linear' layer.
--
-- The single constructor carries no runtime data: the input and output
-- feature counts, the element type, and the device are all recorded in the
-- type parameters.
data
  LinearSpec
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LinearSpec
  deriving (Show, Eq)

-- | A linear layer holding a weight and a bias parameter.
--
-- The shapes of both parameters are fixed by the type-level feature counts,
-- and both live on the same @device@ with the same @dtype@.
data
  Linear
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Linear ::
    forall inputFeatures outputFeatures dtype device.
    { -- weight parameter of shape [outputFeatures, inputFeatures]
      weight :: Parameter device dtype '[outputFeatures, inputFeatures],
      -- bias parameter of shape [outputFeatures]
      bias :: Parameter device dtype '[outputFeatures]
    } ->
    Linear inputFeatures outputFeatures dtype device
  deriving (Show, Generic, Parameterized)

-- | Forward pass of a 'Linear' layer.
--
-- Converts the layer's @weight@ and @bias@ parameters to tensors with
-- 'toDependent' and hands them, together with the input, to 'linear''.
--
-- The constraints on this one are _very_ involved (the result shape is a
-- nested matmul/broadcast type-family computation over the input shape), so
-- a partial type signature is used and GHC infers the full type; this keeps
-- the code significantly cleaner.
linearForward ::
  _ =>
  Linear _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
linearForward Linear {..} input =
  linear' (toDependent weight) (toDependent bias) input

-- | 'HasForward' instance for 'Linear'.
--
-- The output shape @shape'@ is determined by the two equality constraints:
-- @shape''@ is the matmul of the input shape with
-- @[inputFeatures, outputFeatures]@, and @shape'@ is @shape''@ broadcast
-- against itself.
instance
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  HasForward (Linear inputFeatures outputFeatures dtype device) (Tensor device dtype shape) (Tensor device dtype shape')
  where
  -- Deterministic forward pass; delegates to 'linearForward'.
  forward = linearForward

  -- This layer has no stochastic behaviour, so the stochastic variant simply
  -- lifts the result of 'forward' into 'IO'.
  forwardStoch model input = pure (forward model input)

-- | Sample a fresh 'Linear' layer from a 'LinearSpec'.
--
-- Both the weight and the bias are drawn with 'randn' and then turned into
-- parameters with 'makeIndependent'. The tensor shapes are not passed as
-- values; they are fixed by the type of the resulting 'Linear'.
instance
  ( KnownNat inputFeatures,
    KnownNat outputFeatures,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (LinearSpec inputFeatures outputFeatures dtype device)
    (Linear inputFeatures outputFeatures dtype device)
  where
  sample LinearSpec =
    Linear
      -- weight: shape [outputFeatures, inputFeatures]
      <$> (makeIndependent =<< randn)
      -- bias: shape [outputFeatures]
      <*> (makeIndependent =<< randn)