{-# LANGUAGE DataKinds               #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE FunctionalDependencies  #-}
{-# LANGUAGE InstanceSigs            #-}
{-# LANGUAGE MagicHash               #-}
{-# LANGUAGE MultiParamTypeClasses   #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE TypeFamilies            #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE UnboxedTuples           #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE UndecidableSuperClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.DataFrame.Contraction
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module provides a generalization of the matrix product:
--  tensor-like contraction.
-- For matrices and vectors this is the usual matrix*matrix, vector*matrix, or matrix*vector product;
-- for higher-rank arrays it computes the scalar product over the "adjacent" dimensions of the two tensors
-- (the last dimension of the left operand and the first dimension of the right operand).
--
-----------------------------------------------------------------------------

module Numeric.DataFrame.Contraction
  ( Contraction (..), (%*)
  ) where

import GHC.Base

import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.Type
import Numeric.Dimensions



-- | Tensor contraction of two data frames over one shared dimension.
--
--   The last dimension of the left operand must match the first dimension
--   of the right operand. The result's dimensions are the remaining outer
--   dimensions of both operands, related through the 'ConcatList'
--   superclass. The functional dependencies let any two of @as@, @bs@ and
--   @asbs@ determine the third.
class ConcatList as bs asbs
      => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
       | asbs as -> bs
       , asbs bs -> as
       , as bs -> asbs
  where
    -- | Generalization of a matrix product: take the scalar product over the
    --   shared dimension @m@ and concatenate the other dimensions.
    contract
      :: ( KnownDim m
         , PrimArray t (DataFrame t (as +: m))
         , PrimArray t (DataFrame t (m :+ bs))
         , PrimArray t (DataFrame t asbs)
         )
      => DataFrame t (as +: m)
      -> DataFrame t (m :+ bs)
      -> DataFrame t asbs

-- | Tensor contraction; an infix synonym for 'contract'.
--
--   In particular, it covers:
--
--   1. matrix-matrix product;
--
--   2. matrix-vector or vector-matrix product;
--
--   3. dot product of two vectors.
(%*) :: ( Contraction t as bs asbs
        , KnownDim m
        , PrimArray t (DataFrame t (as +: m))
        , PrimArray t (DataFrame t (m :+ bs))
        , PrimArray t (DataFrame t asbs)
        )  => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
(%*) = contract
{-# INLINE (%*) #-}
-- Same precedence and associativity as (*), so @a %* b %* c@ groups left.
infixl 7 %*



-- | The sole 'Contraction' instance: any 'Num' element type, with both outer
--   dimension lists @as@ and @bs@ known at the type level.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         , Num t
         ) => Contraction t as bs asbs where

    contract :: forall m .
                ( KnownDim m
                , PrimArray t (DataFrame t (as +: m))
                , PrimArray t (DataFrame t (m :+ bs))
                , PrimArray t (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
    contract x y = case (# uniqueOrCumulDims x, uniqueOrCumulDims y #) of
        -- Both operands are constant (every element of x is x0 and every
        -- element of y is y0). Each result element is therefore the sum of
        -- m copies of x0 * y0, exactly what the general path computes for
        -- such inputs. Returning x0 * y0 alone is wrong whenever m /= 1.
        -- The sum is accumulated term by term from 0, in the same order as
        -- 'loop' below, so the two paths agree exactly.
        (# Left x0, Left y0 #) ->
          let p = x0 * y0
              sumP :: Word -> t -> t
              sumP 0 acc = acc
              sumP c acc = let acc' = acc + p in acc' `seq` sumP (c - 1) acc'
          in  broadcast (sumP (dimVal (dim @m)) 0)
        -- General case: treat x as an n-by-m matrix and y as an m-by-k
        -- matrix, indexed row-major over the flattened outer dimensions.
        (# ux, uy #)
          | dm <- dim @m
          , (ixX, xs) <- getStepsAndIx (Snoc dims dm) x ux
          , (ixY, ys) <- getStepsAndIx (Cons dm dims) y uy
          -- Value-level n, m, k: flattened size of as, size of the
          -- contracted dimension, flattened size of bs.
          , (# n, m, k, steps #) <- conSteps xs ys ->
            let -- Scalar product of row i of x and column j of y;
                -- l walks the contracted dimension, r is the running sum.
                loop :: Int# -> Int# -> Int# -> t -> t
                loop i j l r
                  | isTrue# (l ==# m) = r
                  | otherwise         = loop i j (l +# 1#)
                      (r + ixX (i *# m +# l) * ixY (l *# k +# j))

                -- Step function for gen#. Yields the element at (i, j) and
                -- the next position, walking rows in order; a position past
                -- the last column wraps to the start of the next row.
                -- NOTE(review): assumes gen# calls this once per result
                -- element; the i == n case is a fallback that yields 0.
                loop2 :: T# -> (# T#, t #)
                loop2 (T# i j)
                  | isTrue# (i ==# n) = (# T# i j, 0 #)
                  | isTrue# (j ==# k) = loop2 (T# (i +# 1#) 0#)
                  | otherwise         = (# T# i (j +# 1#), loop i j 0# 0 #)
            in  case gen# steps loop2 (T# 0# 0#) of
                  (# _, r #) -> r
      where
        -- Element accessor and cumulative dims for one operand. A constant
        -- ('Left') operand carries no stored cumulative dims, so they are
        -- rebuilt from the static dimensions, and every index yields the
        -- same element.
        getStepsAndIx :: forall (ns :: [Nat])
                       . PrimArray t (DataFrame t ns)
                      => Dims ns
                      -> DataFrame t ns
                      -> Either t CumulDims
                      -> (Int# -> t, CumulDims)
        getStepsAndIx _  df (Right cds) = ((`ix#` df), cds)
        getStepsAndIx ds _  (Left  e)   = (\_ -> e, cumulDims ds)

        -- Unboxed wrapper around conSteps': returns (n, m, k, result steps).
        conSteps :: CumulDims -> CumulDims -> (# Int#, Int#, Int#, CumulDims #)
        conSteps (CumulDims xs) (CumulDims ys) = case conSteps' xs ys of
          (W# n, W# m, W# k, zs)
            -> (# word2Int# n, word2Int# m, word2Int# k, CumulDims zs #)

        -- Walks the left operand's cumulative dims down to [m, _]. The
        -- right operand's cumulative dims, minus their head, are the tail of
        -- the result. Each outer cumulative size nm of x is mapped to
        -- n * k (with n = nm / m) in the result's cumulative dims.
        -- NOTE(review): if the contracted dimension m is 0, `quot` below
        -- divides by zero; confirm whether zero-sized dims can reach here.
        conSteps' :: [Word] -> [Word] -> (Word, Word, Word, [Word])
        conSteps' [m, _] (_:ys@(k:_)) = (1, m, k, ys)
        conSteps' (nm:ns) cys
          | (_, m, k, ys) <- conSteps' ns cys
          , n <- nm `quot` m
            = (n, m, k, n*k : ys )
        conSteps' _ _ = error "Numeric.DataFrame.Contraction: impossible match"

-- | Unboxed generator state used while building a contraction result:
--   the current row index and the current column index.
data T# = T#
  Int#  -- row index i
  Int#  -- column index j