-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math.Blas
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Blas functions.
-------------------------------------------------------------------------------
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Blas
  ( addmv
  , addmm
  , addr
  , addbmm
  , baddbmm

  , addmv_
  , addmm_
  , addr_
  , addbmm_
  , baddbmm_

  , dot
  , (<.>)
  ) where


import Control.Monad.Managed (with)
import Control.Monad.IO.Class (liftIO)
import Foreign hiding (with)
import GHC.Int
import System.IO.Unsafe
import Debug.Trace

import Torch.Indef.Types
import Torch.Indef.Dynamic.Tensor
import qualified Torch.Sig.Tensor.Math.Blas as Sig


blasOp
  :: (Ptr CState -> Ptr CTensor -> CReal -> Ptr CTensor -> CReal -> Ptr CTensor -> Ptr CTensor -> IO ())
  -> Dynamic -> HsReal -> Dynamic -> HsReal -> Dynamic -> Dynamic -> IO ()
blasOp fn r a x b y z =
  with2DynamicState r x $ \s' r' x' ->
    with2DynamicState y z $ \_ y' z' ->
      fn s' r' (hs2cReal a) x' (hs2cReal b) y' z'


_addmv   = blasOp Sig.c_addmv
_addmm   = blasOp Sig.c_addmm
_addr    = blasOp Sig.c_addr
_addbmm  = blasOp Sig.c_addbmm
_baddbmm = blasOp Sig.c_baddbmm

-- | Performs the dot product between two tensors. The number of elements must match: both tensors are
-- seen as a 1D vector.
dot :: Dynamic -> Dynamic -> HsAccReal
dot a b = unsafeDupablePerformIO $ flip with (fmap c2hsAccReal)
   $  Sig.c_dot
  <$> managedState
  <*> managedTensor a
  <*> managedTensor b
{-# NOINLINE dot #-}

-- class GPUTensorMathBlas t where
--   btrifact :: t -> IntTensor -> IntTensor -> Int -> t -> io ()
--   btrisolve :: t -> t -> t -> IntTensor -> io ()


-- | inline alias of 'dot'
(<.>) :: Dynamic -> Dynamic -> HsAccReal
(<.>) = dot

mkNewFunction
  :: (Dynamic -> HsReal -> Dynamic -> HsReal -> Dynamic -> Dynamic -> IO ())
  -> HsReal -> Dynamic -> HsReal -> Dynamic -> Dynamic -> Dynamic
mkNewFunction op a m b x y = unsafeDupablePerformIO $
  let r = new' (getSomeDims x)
  in op r a m b x y >> pure r
{-# NOINLINE mkNewFunction #-}

mkInplaceFunction
  :: (Dynamic -> HsReal -> Dynamic -> HsReal -> Dynamic -> Dynamic -> IO ())
  -> HsReal -> Dynamic -> HsReal -> Dynamic -> Dynamic -> IO ()
mkInplaceFunction op a m b x y = op m a m b x y

-- | Performs a matrix-vector multiplication between @mat@ (2D Tensor) and @vec2@
-- (1D Tensor) and add it to @vec1@.
--
-- Values @v1@ and @v2@ are scalars that multiply @vec1@ and @vec2@ respectively.
-- They are optional in C and we may be able to add this to the API in the future.
--
-- In other words,
--
-- @
--   res = (v1 * vec1) + (v2 * (mat * vec2))
-- @
--
-- Sizes must respect the matrix-multiplication operation: if @mat@ is a @n × m@
-- matrix, @vec2@ must be vector of size @m@ and @vec1@ must be a vector of size
-- @n@.
addmv
  :: HsReal    -- ^ v1
  -> Dynamic   -- ^ vec1
  -> HsReal    -- ^ v2
  -> Dynamic   -- ^ mat
  -> Dynamic   -- ^ vec2
  -> Dynamic -- ^ res
addmv = mkNewFunction _addmv

-- | Inline version of 'addmv', mutating @vec1@ inplace.
addmv_
  :: HsReal    -- ^ v1
  -> Dynamic   -- ^ vec1
  -> HsReal    -- ^ v2
  -> Dynamic   -- ^ mat
  -> Dynamic   -- ^ vec2
  -> IO ()
addmv_ = mkInplaceFunction _addmv

-- | Performs a matrix-matrix multiplication between @mat1@ (2D Tensor) and @mat2@ (2D Tensor).
--
-- Values @v1@ and @v2@ are scalars that multiply @M@ and @mat1 * mat2@ respectively.
-- They are optional in C and we may be able to add this to the API in the future.
--
-- In other words,
--
-- @
--   res = (v1 * M) + (v2 * mat1 * mat2)
-- @
--
-- If @mat1@ is a @n × m@ matrix, @mat2@ a @m × p@ matrix, @M@ must be a @n × p@ matrix.
addmm
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ M
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ mat1
  -> Dynamic    -- ^ mat2
  -> Dynamic -- ^ res
addmm = mkNewFunction _addmm

-- | Inline version of 'addmm', mutating @M@ inplace.
addmm_
  :: HsReal    -- ^ v1
  -> Dynamic   -- ^ M
  -> HsReal    -- ^ v2
  -> Dynamic   -- ^ mat1
  -> Dynamic   -- ^ mat2
  -> IO ()
addmm_ = mkInplaceFunction _addmm

-- | Performs the outer-product between @vec1@ (1D Tensor) and @vec2@
-- (1D Tensor).
--
-- Values @v1@ and @v2@ are scalars that multiply @mat_ij@ and @vec1_i [out] vec2_j@ respectively.
-- They are optional in C and we may be able to add this to the API in the future.
--
-- Thus:
--
-- @
--   res_ij = (v1 * mat_ij) + (v2 * vec1_i * vec2_j)
-- @
--
-- If @vec1_@ is a vector of size @i@ and @vec2_j@ is a vector of size @j@, then
-- @mat_ij@ must be a matrix of size @i × j@.
addr
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ mat_ij
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ vec1_i
  -> Dynamic    -- ^ vec2_j
  -> Dynamic -- ^ res_ij
addr = mkNewFunction _addr

-- | Inline version of 'addr', mutating @mat_ij@ in-place.
addr_
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ mat_ij -- mutated inplace
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ vec1_i
  -> Dynamic    -- ^ vec2_j
  -> IO ()
addr_ = mkInplaceFunction _addr

-- | Batch matrix-matrix product of matrices stored in @batch1@ and @batch2@,
-- with a reduced add step (all matrix multiplications get accumulated in
-- a single place).
--
-- @batch1@ and @batch2@ must be 3D Tensors each containing the same number
-- of matrices. If @batch1@ is a @b × n × m@ Tensor, @batch2@ a @b × m × p@
-- Tensor, @res@ will be a @n × p@ Tensor.
--
-- In other words,
--
-- @
--   res = (v1 * M) + (v2 * sum(batch1_i * batch2_i, i = 1, b))
-- @
addbmm
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ M
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ batch1_i
  -> Dynamic    -- ^ batch2_i
  -> Dynamic -- ^ res
addbmm  = mkNewFunction     _addbmm

-- | Inline version of 'addbmm', mutating @M@ in-place.
addbmm_
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ M
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ batch1_i
  -> Dynamic    -- ^ batch2_i
  -> IO ()
addbmm_ = mkInplaceFunction _addbmm

-- | Batch matrix matrix product of matrices stored in batch1 and batch2, with
-- batch add.
--
-- @batch1@ and @batch2@ must be 3D Tensors each containing the same number of
-- matrices. If @batch1@ is a @b × n × m@ Tensor, @batch2@ a @b × m × p@ Tensor,
-- @res@ will be a @b × n × p@ Tensor.
--
-- In other words,
--
-- @
--   res_i = (v1 * M_i) + (v2 * batch1_i * batch2_i)
-- @
baddbmm
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ M_i
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ batch1_i
  -> Dynamic    -- ^ batch2_i
  -> Dynamic -- ^ res_i
baddbmm  = mkNewFunction     _baddbmm

-- | Inline version of 'baddbmm', mutating @M_i@ in-place.
baddbmm_
  :: HsReal     -- ^ v1
  -> Dynamic    -- ^ M_i
  -> HsReal     -- ^ v2
  -> Dynamic    -- ^ batch1_i
  -> Dynamic    -- ^ batch2_i
  -> IO ()
baddbmm_ = mkInplaceFunction _baddbmm