{-# LANGUAGE CPP                    #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UnboxedTuples          #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE InstanceSigs           #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.DataFrame.Contraction
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module provides a generalization of the matrix product:
--  tensor-like contraction.
-- For matrices and vectors this is the usual matrix*matrix, vector*matrix or matrix*vector product;
-- for higher-dimensional arrays it calculates the scalar product over the "adjacent" dimensions of two tensors.
--
-----------------------------------------------------------------------------

module Numeric.DataFrame.Contraction
  ( Contraction (..), (%*)
  ) where

#include "MachDeps.h"

import           Data.Int               (Int16, Int32, Int64, Int8)
import           Data.Word              (Word16, Word32, Word64, Word8)
import           Data.Type.Equality     ((:~:) (..))
import           GHC.Base               (runRW#)
import           GHC.Prim
import           GHC.Types              (Int (..), RuntimeRep (..), Type,
                                         Word (..), isTrue#)
import           Unsafe.Coerce          (unsafeCoerce)

import           Numeric.Array.Family
import           Numeric.Commons
import           Numeric.DataFrame.Type
import           Numeric.Dimensions
import           Numeric.TypeLits



-- | Contraction of two data frames over one shared dimension.
--
--   * @t@    - element type of both operands and of the result;
--   * @as@   - dimensions of the left operand without its last dimension;
--   * @bs@   - dimensions of the right operand without its first dimension;
--   * @asbs@ - dimensions of the result.
--
--   The functional dependencies state that any two of @as@, @bs@, @asbs@
--   determine the third one.
--   NOTE(review): the @ConcatList@ superclass presumably asserts that @asbs@
--   is the concatenation of @as@ and @bs@ - confirm against its definition.
class ConcatList as bs asbs
      => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
                             | asbs as -> bs, asbs bs -> as, as bs -> asbs where
    -- | Generalization of a matrix product: take the scalar product over one
    --   dimension (@m@, the last dimension of the first argument and the first
    --   dimension of the second argument) and, thus, concatenate the other
    --   dimensions.
    contract :: ( KnownDim m
                , PrimBytes (DataFrame t (as +: m))
                , PrimBytes (DataFrame t (m :+ bs))
                , PrimBytes (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs

-- | Tensor contraction; an infix synonym for 'contract' whose result
--   dimensions are spelled out as @as ++ bs@.
--   In particular:
--     1. matrix-matrix product
--     2. matrix-vector or vector-matrix product
--     3. dot product of two vectors.
--
--   The operator is left-associative with precedence 7 (see the fixity
--   declaration below), i.e. the same fixity as the Prelude multiplication.
(%*) :: ( ConcatList as bs (as ++ bs)
        , Contraction t as bs asbs
        , KnownDim m
        , PrimBytes (DataFrame t (as +: m))
        , PrimBytes (DataFrame t (m :+ bs))
        , PrimBytes (DataFrame t (as ++ bs))
        )  => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t (as ++ bs)
(%*) = contract
{-# INLINE (%*) #-}
infixl 7 %*


--------------------------------------------------------------------------------

-- | Contraction of 'Float' data frames; the summation itself is done by 'prodF'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Float as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Float is always backed by Float# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Float (m ': bs)) :~: 'FloatRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Float (m ': bs)) :~:  Float#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Float (as +: m)) :~: 'FloatRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Float (as +: m)) :~:  Float#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodF nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy


-- | Contraction of 'Double' data frames; the summation itself is done by 'prodD'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Double as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Double is always backed by Double# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Double (m ': bs)) :~: 'DoubleRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Double (m ': bs)) :~:  Double#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Double (as +: m)) :~: 'DoubleRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Double (as +: m)) :~:  Double#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodD nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Int' data frames; the summation itself is done by 'prodI'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Int is always backed by Int# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int (m ': bs)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int (m ': bs)) :~:  Int#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int (as +: m)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int (as +: m)) :~:  Int#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodI nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Int8' data frames; the summation itself is done by 'prodI8'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int8 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Int8 elements are read as Int# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int8 (m ': bs)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int8 (m ': bs)) :~:  Int#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int8 (as +: m)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int8 (as +: m)) :~:  Int#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodI8 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Int16' data frames; the summation itself is done by 'prodI16'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int16 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Int16 elements are read as Int# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int16 (m ': bs)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int16 (m ': bs)) :~:  Int#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int16 (as +: m)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int16 (as +: m)) :~:  Int#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodI16 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Int32' data frames; the summation itself is done by 'prodI32'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int32 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Int32 elements are read as Int# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int32 (m ': bs)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int32 (m ': bs)) :~:  Int#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int32 (as +: m)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int32 (as +: m)) :~:  Int#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodI32 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy


-- | Contraction of 'Int64' data frames; the summation itself is done by 'prodI64'.
--   The asserted primitive element type depends on the machine word size.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int64 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if it matches the Array Int64 backing type - confirm.
#if WORD_SIZE_IN_BITS < 64
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int64 (m ': bs)) :~: 'Int64Rep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int64 (m ': bs)) :~:  Int64#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int64 (as +: m)) :~: 'Int64Rep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int64 (as +: m)) :~:  Int64#)
#else
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int64 (m ': bs)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int64 (m ': bs)) :~:  Int#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Int64 (as +: m)) :~: 'IntRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Int64 (as +: m)) :~:  Int#)
#endif
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodI64 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy



-- | Contraction of 'Word' data frames; the summation itself is done by 'prodW'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Word is always backed by Word# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word (m ': bs)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word (m ': bs)) :~:  Word#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word (as +: m)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word (as +: m)) :~:  Word#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodW nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Word8' data frames; the summation itself is done by 'prodW8'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word8 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Word8 elements are read as Word# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word8 (m ': bs)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word8 (m ': bs)) :~:  Word#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word8 (as +: m)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word8 (as +: m)) :~:  Word#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodW8 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Word16' data frames; the summation itself is done by 'prodW16'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word16 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Word16 elements are read as Word# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word16 (m ': bs)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word16 (m ': bs)) :~:  Word#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word16 (as +: m)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word16 (as +: m)) :~:  Word#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodW16 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Word32' data frames; the summation itself is done by 'prodW32'.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word32 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if Array Word32 elements are read as Word# - confirm.
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word32 (m ': bs)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word32 (m ': bs)) :~:  Word#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word32 (as +: m)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word32 (as +: m)) :~:  Word#)
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodW32 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy

-- | Contraction of 'Word64' data frames; the summation itself is done by 'prodW64'.
--   The asserted primitive element type depends on the machine word size.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word64 as bs asbs where
    contract lhs rhs
        | (midDim :: Proxy m) <- headDim rhs
          -- Assert (by coercion) the primitive element type of both operands.
          -- NOTE(review): sound only if it matches the Array Word64 backing type - confirm.
#if WORD_SIZE_IN_BITS < 64
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word64 (m ': bs)) :~: 'Word64Rep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word64 (m ': bs)) :~:  Word64#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word64 (as +: m)) :~: 'Word64Rep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word64 (as +: m)) :~:  Word64#)
#else
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word64 (m ': bs)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word64 (m ': bs)) :~:  Word#)
        , Refl <- (unsafeCoerce Refl :: ElemRep  (Array Word64 (as +: m)) :~: 'WordRep)
        , Refl <- (unsafeCoerce Refl :: ElemPrim (Array Word64 (as +: m)) :~:  Word#)
#endif
          -- Sizes handed to the kernel: totalDim of as, the contracted m, totalDim of bs.
        , I# nLead <- totalDim (Proxy @as)
        , I# nMid  <- intNatVal midDim
        , I# nTail <- totalDim (Proxy @bs)
        = prodW64 nLead nMid nTail lhs rhs
      where
        -- Recover the first dimension of the right operand as a proxy.
        headDim :: forall d f . f (d ': bs) -> Proxy d
        headDim _ = Proxy


-- | Contraction kernel for operands whose primitive element type is Float.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodF :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Float#, ElemRep a ~ 'FloatRep
         , ElemPrim b ~ Float#, ElemRep b ~ 'FloatRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodF n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeFloatArray# out (i +# n *# j) (dot i j 0# 0.0#) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc `plusFloat#` timesFloat# (ix (i +# n *# l) x) (ix (l +# m *# j) y))
{-# INLINE prodF #-}

-- | Contraction kernel for operands whose primitive element type is Double.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodD :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Double#, ElemRep a ~ 'DoubleRep
         , ElemPrim b ~ Double#, ElemRep b ~ 'DoubleRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodD n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeDoubleArray# out (i +# n *# j) (dot i j 0# 0.0##) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc +## (ix (i +# n *# l) x *## ix (l +# m *# j) y))
{-# INLINE prodD #-}

-- | Contraction kernel for operands read as machine integers and written
--   back with @writeIntArray#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodI :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
         , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodI n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeIntArray# out (i +# n *# j) (dot i j 0# 0#) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc +# (ix (i +# n *# l) x *# ix (l +# m *# j) y))
{-# INLINE prodI #-}

-- | Contraction kernel for operands read as machine integers and written
--   back with @writeInt8Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine integer before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodI8 :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
         , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodI8 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeInt8Array# out (i +# n *# j) (dot i j 0# 0#) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc +# (ix (i +# n *# l) x *# ix (l +# m *# j) y))
{-# INLINE prodI8 #-}


-- | Contraction kernel for operands read as machine integers and written
--   back with @writeInt16Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine integer before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodI16 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodI16 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeInt16Array# out (i +# n *# j) (dot i j 0# 0#) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc +# (ix (i +# n *# l) x *# ix (l +# m *# j) y))
{-# INLINE prodI16 #-}


-- | Contraction kernel for operands read as machine integers and written
--   back with @writeInt32Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine integer before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodI32 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodI32 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeInt32Array# out (i +# n *# j) (dot i j 0# 0#) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc +# (ix (i +# n *# l) x *# ix (l +# m *# j) y))
{-# INLINE prodI32 #-}


-- | Contraction kernel for 64-bit signed integer elements.
--
--   On platforms with a 64-bit machine word: for every @i@ in @[0, n)@ and
--   @j@ in @[0, k)@ the element at offset @i + n * j@ of a freshly allocated
--   byte array is set to the sum over @l@ in @[0, m)@ of
--   @ix (i + n * l) x * ix (l + m * j) y@, and the frozen array is handed to
--   'fromBytes' together with @0@ and @n * k@.
--
--   On platforms with a smaller machine word no kernel is implemented; the
--   function then fails with a descriptive 'error' as soon as it is evaluated.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodI64 :: ( PrimBytes a, PrimBytes b, PrimBytes c
#if WORD_SIZE_IN_BITS < 64
           , ElemPrim a ~ Int64#, ElemRep a ~ 'Int64Rep
           , ElemPrim b ~ Int64#, ElemRep b ~ 'Int64Rep
#else
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
#endif
           ) => Int# -> Int# -> Int# -> a -> b -> c
#if WORD_SIZE_IN_BITS < 64
-- Not implemented for narrow words: report what is missing instead of a bare undefined.
prodI64 = error "Numeric.DataFrame.Contraction.prodI64: contraction of Int64 data frames is not implemented on platforms with a machine word narrower than 64 bits"
#else
prodI64 n m k x y= case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
               -- loop' accumulates the sum over the contracted index l
           let loop' i j l r | isTrue# (l ==# m) = r
                             | otherwise = loop' i j (l +# 1#) (r +# (*#) (ix (i +# n *# l) x)
                                                                          (ix (l +# m *# j) y))
               -- fill every (i, j) cell of the output, then freeze it
           in case loop2# n k
               (\i j s' -> writeInt64Array# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k,  r #)
    where
      -- size of the output buffer in bytes
      bs = n *# k *# elementByteSize x
{-# INLINE prodI64 #-}
#endif

-- | Contraction kernel for operands read as machine words and written
--   back with @writeWordArray#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodW :: ( PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
         , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodW n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeWordArray# out (i +# n *# j) (dot i j 0# 0##) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc `plusWord#` timesWord# (ix (i +# n *# l) x) (ix (l +# m *# j) y))
{-# INLINE prodW #-}

-- | Contraction kernel for operands read as machine words and written
--   back with @writeWord8Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine word before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodW8 :: ( PrimBytes a, PrimBytes b, PrimBytes c
          , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
          , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
          ) => Int# -> Int# -> Int# -> a -> b -> c
prodW8 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeWord8Array# out (i +# n *# j) (dot i j 0# 0##) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc `plusWord#` timesWord# (ix (i +# n *# l) x) (ix (l +# m *# j) y))
{-# INLINE prodW8 #-}


-- | Contraction kernel for operands read as machine words and written
--   back with @writeWord16Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine word before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodW16 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodW16 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeWord16Array# out (i +# n *# j) (dot i j 0# 0##) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc `plusWord#` timesWord# (ix (i +# n *# l) x) (ix (l +# m *# j) y))
{-# INLINE prodW16 #-}

-- | Contraction kernel for operands read as machine words and written
--   back with @writeWord32Array#@.
--
--   For every @i@ in @[0, n)@ and @j@ in @[0, k)@ the element at offset
--   @i + n * j@ of a freshly allocated byte array is set to the sum over
--   @l@ in @[0, m)@ of @ix (i + n * l) x * ix (l + m * j) y@; the sum is
--   accumulated as a full-width machine word before being stored.
--   The frozen array is handed to 'fromBytes' together with @0@ and @n * k@.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodW32 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodW32 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes (# 0#, total, frozen #)
  where
    -- number of elements in the result
    total = n *# k
    -- allocate the output, fill every (i, j) cell, then freeze it
    build s0 = case newByteArray# (total *# elementByteSize x) s0 of
      (# s1, out #) -> unsafeFreezeByteArray# out
        ( loop2# n k
            (\i j s -> writeWord32Array# out (i +# n *# j) (dot i j 0# 0##) s)
            s1 )
    -- running sum over the contracted index l
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise         = dot i j (l +# 1#)
          (acc `plusWord#` timesWord# (ix (i +# n *# l) x) (ix (l +# m *# j) y))
{-# INLINE prodW32 #-}

-- | Contraction kernel for 64-bit unsigned integer elements.
--
--   On platforms with a 64-bit machine word: for every @i@ in @[0, n)@ and
--   @j@ in @[0, k)@ the element at offset @i + n * j@ of a freshly allocated
--   byte array is set to the sum over @l@ in @[0, m)@ of
--   @ix (i + n * l) x * ix (l + m * j) y@, and the frozen array is handed to
--   'fromBytes' together with @0@ and @n * k@.
--
--   On platforms with a smaller machine word no kernel is implemented; the
--   function then fails with a descriptive 'error' as soon as it is evaluated.
--
--   NOTE(review): the buffer is sized with @elementByteSize x@, so the result
--   @c@ is assumed to have the same element size as @x@ - confirm.
prodW64 :: ( PrimBytes a, PrimBytes b, PrimBytes c
#if WORD_SIZE_IN_BITS < 64
           , ElemPrim a ~ Word64#, ElemRep a ~ 'Word64Rep
           , ElemPrim b ~ Word64#, ElemRep b ~ 'Word64Rep
#else
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
#endif
           ) => Int# -> Int# -> Int# -> a -> b -> c
#if WORD_SIZE_IN_BITS < 64
-- Not implemented for narrow words: report what is missing instead of a bare undefined.
prodW64 = error "Numeric.DataFrame.Contraction.prodW64: contraction of Word64 data frames is not implemented on platforms with a machine word narrower than 64 bits"
#else
prodW64 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
               -- loop' accumulates the sum over the contracted index l
           let loop' i j l r | isTrue# (l ==# m) = r
                             | otherwise = loop' i j (l +# 1#) (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                                                                         (ix (l +# m *# j) y))
               -- fill every (i, j) cell of the output, then freeze it
           in case loop2# n k
               (\i j s' -> writeWord64Array# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k,  r #)
    where
      -- size of the output buffer in bytes
      bs = n *# k *# elementByteSize x
{-# INLINE prodW64 #-}
#endif

-- | Thread a state token through @f i j@ for every @j@ from 0 to @m-1@ and,
--   within each @j@, every @i@ from 0 to @n-1@ (so @i@ is the fastest-changing index).
--   The action is never run when @n@ or @m@ is zero.
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s) -> State# s -> State# s
loop2# n m f = outer 0#
  where
    -- advance the slow index j, running a full pass over i for each value
    outer j s
      | isTrue# (j ==# m) = s
      | otherwise         = outer (j +# 1#) (inner 0# j s)
    -- run the action for i = 0 .. n-1 at a fixed j
    inner i j s
      | isTrue# (i ==# n) = s
      | otherwise         = inner (i +# 1#) j (f i j s)
{-# INLINE loop2# #-}