{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE InstanceSigs           #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UnboxedTuples          #-}
{-# LANGUAGE UndecidableInstances   #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.DataFrame.Contraction
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module provides a generalization of the matrix product:
--  tensor-like contraction.
-- For matrices and vectors this is a normal matrix*matrix or vector*matrix or matrix*vector product,
-- for larger dimensions it calculates the scalar product of "adjacent" dimensions of a tensor.
--
-----------------------------------------------------------------------------

module Numeric.DataFrame.Contraction
  ( Contraction (..), (%*)
  ) where

import           GHC.Base

import           Numeric.DataFrame.Family
import           Numeric.DataFrame.Internal.Array.Class
import           Numeric.Dimensions



class ConcatList as bs asbs
      => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
                             | asbs as -> bs, asbs bs -> as, as bs -> asbs where
    -- | Generalization of a matrix product: take scalar product over one dimension
    --   and, thus, concatenate other dimesnions
    contract :: ( KnownDim m
                , PrimArray t (DataFrame t (as +: m))
                , PrimArray t (DataFrame t (m :+ bs))
                , PrimArray t (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs

-- | Tensor contraction.
--   In particular:
--     1. matrix-matrix product
--     2. matrix-vector or vector-matrix product
--     3. dot product of two vectors.
(%*) :: ( ConcatList as bs (as ++ bs)
        , Contraction t as bs asbs
        , KnownDim m
        , PrimArray t (DataFrame t (as +: m))
        , PrimArray t (DataFrame t (m :+ bs))
        , PrimArray t (DataFrame t (as ++ bs))
        )  => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t (as ++ bs)
(%*) = contract
{-# INLINE (%*) #-}
infixl 7 %*



instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         , Num t
         ) => Contraction t as bs asbs where
    contract :: forall m .
                ( KnownDim m
                , PrimArray t (DataFrame t (as +: m))
                , PrimArray t (DataFrame t (m :+ bs))
                , PrimArray t (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
    contract x y
        | I# m <- fromIntegral $ dimVal' @m
        , I# n <- fromIntegral $ totalDim' @as
        , I# k <- fromIntegral $ totalDim' @bs
        , nk <- n *# k
        = let loop i j l r | isTrue# (l ==# m) = r
                           | otherwise = loop i j (l +# 1#)
                              (r + ix# (i +# n *# l) x * ix# (l +# m *# j) y)

              loop2 (T# i j) | isTrue# (j ==# k) = (# T# i j, 0 #)
                             | isTrue# (i ==# n) = loop2 (T# 0# (j +# 1#))
                             | otherwise = (# T# (i +# 1#) j, loop i j 0# 0 #)
          in case gen# nk loop2 (T# 0# 0#) of
              (# _, r #) -> r

data T# = T# Int# Int#