{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.DataFrame.Contraction
( Contraction (..), (%*)
) where
import GHC.Base
import Numeric.DataFrame.Family
import Numeric.DataFrame.Internal.Array.Class
import Numeric.Dimensions
-- | Contraction of two data frames over one shared dimension @m@.
--
--   The first argument has shape @as +: m@ (with @m@ as its last dimension).
--   The second has shape @m :+ bs@ (with @m@ as its first dimension).
--   The result has shape @asbs@, which the 'ConcatList' superclass ties to
--   @as@ and @bs@.
--
--   The functional dependencies let any two of @as@, @bs@ and @asbs@
--   determine the third.
class ConcatList as bs asbs
   => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
    | asbs as -> bs
    , asbs bs -> as
    , as bs -> asbs
  where
    -- | Contract the last dimension of the first argument against the
    --   first dimension of the second argument.
    contract :: forall m . ( KnownDim m
                           , PrimArray t (DataFrame t (as +: m))
                           , PrimArray t (DataFrame t (m :+ bs))
                           , PrimArray t (DataFrame t asbs)
                           )
             => DataFrame t (as +: m)
             -> DataFrame t (m :+ bs)
             -> DataFrame t asbs
infixl 7 %*

-- | Infix form of 'contract', with the result shape written as @as ++ bs@.
--
--   The last dimension of the left operand is contracted against the first
--   dimension of the right operand. For example, when both @as@ and @bs@
--   are empty, the operands have shape @[m]@ and the result has no
--   dimensions left.
(%*) :: ( ConcatList as bs (as ++ bs)
        , Contraction t as bs asbs
        , KnownDim m
        , PrimArray t (DataFrame t (as +: m))
        , PrimArray t (DataFrame t (m :+ bs))
        , PrimArray t (DataFrame t (as ++ bs))
        )
     => DataFrame t (as +: m)
     -> DataFrame t (m :+ bs)
     -> DataFrame t (as ++ bs)
lhs %* rhs = contract lhs rhs
{-# INLINE (%*) #-}
-- | Generic contraction via an explicit multiply-accumulate loop.
--
--   Let @n = totalDim' \@as@, @m = dimVal' \@m@ and @k = totalDim' \@bs@.
--   Result element number @i + n*j@ (for @0 <= i < n@, @0 <= j < k@) is
--   the sum over @0 <= l < m@ of @x[i + n*l] * y[l + m*j]@, where @x[p]@
--   denotes @ix# p x@.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         , Num t
         ) => Contraction t as bs asbs where
  contract :: forall m .
              ( KnownDim m
              , PrimArray t (DataFrame t (as +: m))
              , PrimArray t (DataFrame t (m :+ bs))
              , PrimArray t (DataFrame t asbs)
              )
           => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
  contract x y
    | I# m <- fromIntegral $ dimVal' @m
    , I# n <- fromIntegral $ totalDim' @as
    , I# k <- fromIntegral $ totalDim' @bs
    , nk <- n *# k
    = let
        -- Sum of x[i + n*l] * y[l + m*j] over l in [l, m), added onto r.
        -- The partial sum is forced with `seq` at every step. Without this,
        -- an element type whose (+) is not known to be strict accumulates
        -- a chain of up to m unevaluated additions per output element.
        loop i j l r
          | isTrue# (l ==# m) = r
          | otherwise =
              let r' = r + ix# (i +# n *# l) x * ix# (l +# m *# j) y
              in  r' `seq` loop i j (l +# 1#) r'
        -- Step function handed to gen#. It yields the value for the current
        -- (i, j) and advances i first, moving to (0, j + 1) once i reaches
        -- n. The j == k branch is a fallback that yields 0 once every j has
        -- been visited.
        -- NOTE(review): assumes gen# applies this step nk times and stores
        -- the yielded values in order -- confirm against PrimArray.
        loop2 (T# i j)
          | isTrue# (j ==# k) = (# T# i j, 0 #)
          | isTrue# (i ==# n) = loop2 (T# 0# (j +# 1#))
          | otherwise         = (# T# (i +# 1#) j, loop i j 0# 0 #)
      in case gen# nk loop2 (T# 0# 0#) of
           (# _, r #) -> r
-- | Generator state threaded through 'gen#' in 'contract': a pair of
--   unboxed indices @(i, j)@.
data T#
  = T# Int# -- i: index in [0, totalDim as)
       Int# -- j: index in [0, totalDim bs)