{-# LANGUAGE CPP                    #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UnboxedTuples          #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE InstanceSigs           #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.DataFrame.Contraction
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module provides a generalization of the matrix product:
--  tensor-like contraction.
-- For matrices and vectors this is a normal matrix*matrix or vector*matrix or matrix*vector product,
-- for larger dimensions it calculates the scalar product of "adjacent" dimensions of a tensor.
--
-----------------------------------------------------------------------------

module Numeric.DataFrame.Contraction
  ( Contraction (..), (%*)
  ) where

#include "MachDeps.h"

import           Data.Int                  (Int16, Int32, Int64, Int8)
import           Data.Word                 (Word16, Word32, Word64, Word8)
import           Data.Type.Equality        ((:~:) (..))
import           GHC.Base                  (runRW#)
import           GHC.Prim
import           GHC.Types                 (Int (..), RuntimeRep (..), Type, Word (..), isTrue#)
import           Unsafe.Coerce             (unsafeCoerce)

import           Numeric.Array.Family
import           Numeric.Commons
import           Numeric.DataFrame.Type
import           Numeric.Dimensions
import           Numeric.TypeLits


class ConcatList as bs asbs
      => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
                             | asbs as -> bs, asbs bs -> as, as bs -> asbs where
    -- | Generalization of a matrix product: take scalar product over one dimension
    --   and, thus, concatenate other dimensions.
    --
    --   The first argument carries the contracted dimension @m@ as its last
    --   dimension, the second argument carries it as its first dimension.
    contract :: ( KnownDim m
                , PrimBytes (DataFrame t (as +: m))
                , PrimBytes (DataFrame t (m :+ bs))
                , PrimBytes (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs

-- | Tensor contraction.
--   In particular:
--     1. matrix-matrix product
--     2. matrix-vector or vector-matrix product
--     3. dot product of two vectors.
--
--   This is an infix synonym for the class method of the Contraction class.
(%*) :: ( ConcatList as bs (as ++ bs)
        , Contraction t as bs asbs
        , KnownDim m
        , PrimBytes (DataFrame t (as +: m))
        , PrimBytes (DataFrame t (m :+ bs))
        , PrimBytes (DataFrame t (as ++ bs))
        )  => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t (as ++ bs)
(%*) = contract
{-# INLINE (%*) #-}
infixl 7 %*

--------------------------------------------------------------------------------

-- All instances below share one scheme:
--
--   1. the contracted dimension m is recovered from the type of the second
--      argument via the local helper getM and turned into a number;
--   2. n and k are the total element counts of the dimension lists as and bs;
--   3. four (unsafeCoerce Refl) guards bring type equalities into scope that
--      fix ElemRep and ElemPrim of both argument arrays to the primitive
--      representation of the element type, presumably to discharge the
--      matching constraints of the kernel that is called afterwards.
--
-- NOTE(review): the equalities in step 3 are asserted, not proven. They are
-- sound only if every Array instance for the given element type really uses
-- that primitive representation; confirm against Numeric.Array.Family before
-- adding element types or changing the array representation.

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Float as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Float (m : bs) ) :~: 'FloatRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Float (m : bs) ) :~: Float#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Float (as +: m)) :~: 'FloatRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Float (as +: m)) :~: Float#
        = prodF n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Double as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Double (m : bs) ) :~: 'DoubleRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Double (m : bs) ) :~: Double#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Double (as +: m)) :~: 'DoubleRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Double (as +: m)) :~: Double#
        = prodD n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int (m : bs) ) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int (m : bs) ) :~: Int#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int (as +: m)) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int (as +: m)) :~: Int#
        = prodI n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int8 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int8 (m : bs) ) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int8 (m : bs) ) :~: Int#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int8 (as +: m)) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int8 (as +: m)) :~: Int#
        = prodI8 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int16 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int16 (m : bs) ) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int16 (m : bs) ) :~: Int#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int16 (as +: m)) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int16 (as +: m)) :~: Int#
        = prodI16 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int32 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int32 (m : bs) ) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int32 (m : bs) ) :~: Int#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int32 (as +: m)) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int32 (as +: m)) :~: Int#
        = prodI32 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Int64 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
#if WORD_SIZE_IN_BITS < 64
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int64 (m : bs) ) :~: 'Int64Rep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int64 (m : bs) ) :~: Int64#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int64 (as +: m)) :~: 'Int64Rep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int64 (as +: m)) :~: Int64#
#else
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int64 (m : bs) ) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int64 (m : bs) ) :~: Int#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Int64 (as +: m)) :~: 'IntRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Int64 (as +: m)) :~: Int#
#endif
        = prodI64 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word (m : bs) ) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word (m : bs) ) :~: Word#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word (as +: m)) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word (as +: m)) :~: Word#
        = prodW n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word8 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word8 (m : bs) ) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word8 (m : bs) ) :~: Word#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word8 (as +: m)) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word8 (as +: m)) :~: Word#
        = prodW8 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word16 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word16 (m : bs) ) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word16 (m : bs) ) :~: Word#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word16 (as +: m)) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word16 (as +: m)) :~: Word#
        = prodW16 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word32 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word32 (m : bs) ) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word32 (m : bs) ) :~: Word#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word32 (as +: m)) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word32 (as +: m)) :~: Word#
        = prodW32 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         ) => Contraction Word64 as bs asbs where
    contract x y
        | (pm :: Proxy m) <- getM y
        , I# m <- intNatVal pm
        , I# n <- totalDim (Proxy @as)
        , I# k <- totalDim (Proxy @bs)
#if WORD_SIZE_IN_BITS < 64
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word64 (m : bs) ) :~: 'Word64Rep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word64 (m : bs) ) :~: Word64#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word64 (as +: m)) :~: 'Word64Rep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word64 (as +: m)) :~: Word64#
#else
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word64 (m : bs) ) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word64 (m : bs) ) :~: Word#
        , Refl <- unsafeCoerce Refl :: ElemRep  (Array Word64 (as +: m)) :~: 'WordRep
        , Refl <- unsafeCoerce Refl :: ElemPrim (Array Word64 (as +: m)) :~: Word#
#endif
        = prodW64 n m k x y
      where
        getM :: forall m p . p (m ': bs) -> Proxy m
        getM _ = Proxy

--------------------------------------------------------------------------------
-- Contraction kernels
--
-- Every kernel below takes the three sizes n, m, k (as computed by the
-- instances: total size of as, the contracted dimension, total size of bs)
-- followed by the two operands x and y, and works the same way:
--
--   * a fresh mutable byte array of n * k * elementByteSize x bytes is
--     allocated;
--   * for every i in [0, n) and j in [0, k) the element at index i + n * j
--     is set to the sum over l in [0, m) of
--         ix (i + n * l) x  *  ix (l + m * j) y
--     where the local helper loop' walks l and carries the running sum r;
--   * the array is frozen and handed to fromBytes together with 0 and n * k.
--     NOTE(review): these look like an element offset and an element count;
--     confirm against the PrimBytes class in Numeric.Commons.
--
-- The kernels differ only in the primitive arithmetic, the zero literal and
-- the array write primop they use.
--------------------------------------------------------------------------------

-- Kernel for Float elements (plusFloat#, timesFloat#, writeFloatArray#).
prodF :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Float#, ElemRep a ~ 'FloatRep
         , ElemPrim b ~ Float#, ElemRep b ~ 'FloatRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodF n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusFloat#` timesFloat# (ix (i +# n *# l) x)
                                                 (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeFloatArray# marr (i +# n *# j) (loop' i j 0# 0.0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodF #-}

-- Kernel for Double elements (+##, *##, writeDoubleArray#).
prodD :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Double#, ElemRep a ~ 'DoubleRep
         , ElemPrim b ~ Double#, ElemRep b ~ 'DoubleRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodD n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +## (*##) (ix (i +# n *# l) x)
                                  (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeDoubleArray# marr (i +# n *# j) (loop' i j 0# 0.0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodD #-}

-- Kernel for Int elements (+#, *#, writeIntArray#).
prodI :: (PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
         , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodI n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +# (*#) (ix (i +# n *# l) x)
                                (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeIntArray# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodI #-}

-- Kernel for Int8 elements. The sum is accumulated as a machine-sized Int#
-- and only stored through writeInt8Array#.
-- NOTE(review): presumably the store keeps just the low 8 bits, giving
-- wrap-around results; confirm against the primop documentation.
prodI8 :: (PrimBytes a, PrimBytes b, PrimBytes c
          , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
          , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
          ) => Int# -> Int# -> Int# -> a -> b -> c
prodI8 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +# (*#) (ix (i +# n *# l) x)
                                (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeInt8Array# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodI8 #-}

-- Kernel for Int16 elements; accumulates as Int#, stores with
-- writeInt16Array#.
prodI16 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodI16 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +# (*#) (ix (i +# n *# l) x)
                                (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeInt16Array# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodI16 #-}

-- Kernel for Int32 elements; accumulates as Int#, stores with
-- writeInt32Array#.
prodI32 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodI32 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +# (*#) (ix (i +# n *# l) x)
                                (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeInt32Array# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodI32 #-}

-- Kernel for Int64 elements; accumulates as Int#, stores with
-- writeInt64Array#. Only implemented for targets whose machine word is at
-- least 64 bits wide; on narrower targets calling it raises a descriptive
-- error instead of an anonymous undefined.
prodI64 :: ( PrimBytes a, PrimBytes b, PrimBytes c
#if WORD_SIZE_IN_BITS < 64
           , ElemPrim a ~ Int64#, ElemRep a ~ 'Int64Rep
           , ElemPrim b ~ Int64#, ElemRep b ~ 'Int64Rep
#else
           , ElemPrim a ~ Int#, ElemRep a ~ 'IntRep
           , ElemPrim b ~ Int#, ElemRep b ~ 'IntRep
#endif
           ) => Int# -> Int# -> Int# -> a -> b -> c
#if WORD_SIZE_IN_BITS < 64
prodI64 = error
  "Numeric.DataFrame.Contraction.prodI64: contraction of Int64 data frames is not implemented on platforms with a word size below 64 bits"
#else
prodI64 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r +# (*#) (ix (i +# n *# l) x)
                                (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeInt64Array# marr (i +# n *# j) (loop' i j 0# 0#) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodI64 #-}
#endif

-- Kernel for Word elements (plusWord#, timesWord#, writeWordArray#).
prodW :: ( PrimBytes a, PrimBytes b, PrimBytes c
         , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
         , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
         ) => Int# -> Int# -> Int# -> a -> b -> c
prodW n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                               (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeWordArray# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodW #-}

-- Kernel for Word8 elements; accumulates as Word#, stores with
-- writeWord8Array#.
prodW8 :: ( PrimBytes a, PrimBytes b, PrimBytes c
          , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
          , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
          ) => Int# -> Int# -> Int# -> a -> b -> c
prodW8 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                               (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeWord8Array# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodW8 #-}

-- Kernel for Word16 elements; accumulates as Word#, stores with
-- writeWord16Array#.
prodW16 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodW16 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                               (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeWord16Array# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodW16 #-}

-- Kernel for Word32 elements; accumulates as Word#, stores with
-- writeWord32Array#.
prodW32 :: ( PrimBytes a, PrimBytes b, PrimBytes c
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
           ) => Int# -> Int# -> Int# -> a -> b -> c
prodW32 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                               (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeWord32Array# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodW32 #-}

-- Kernel for Word64 elements; accumulates as Word#, stores with
-- writeWord64Array#. Only implemented for targets whose machine word is at
-- least 64 bits wide; on narrower targets calling it raises a descriptive
-- error instead of an anonymous undefined.
prodW64 :: ( PrimBytes a, PrimBytes b, PrimBytes c
#if WORD_SIZE_IN_BITS < 64
           , ElemPrim a ~ Word64#, ElemRep a ~ 'Word64Rep
           , ElemPrim b ~ Word64#, ElemRep b ~ 'Word64Rep
#else
           , ElemPrim a ~ Word#, ElemRep a ~ 'WordRep
           , ElemPrim b ~ Word#, ElemRep b ~ 'WordRep
#endif
           ) => Int# -> Int# -> Int# -> a -> b -> c
#if WORD_SIZE_IN_BITS < 64
prodW64 = error
  "Numeric.DataFrame.Contraction.prodW64: contraction of Word64 data frames is not implemented on platforms with a word size below 64 bits"
#else
prodW64 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let loop' i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = loop' i j (l +# 1#)
                     (r `plusWord#` timesWord# (ix (i +# n *# l) x)
                                               (ix (l +# m *# j) y))
           in case loop2# n k
               (\i j s' -> writeWord64Array# marr (i +# n *# j) (loop' i j 0# 0##) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# k, r #)
  where
    -- size of the result buffer in bytes
    bs = n *# k *# elementByteSize x
{-# INLINE prodW64 #-}
#endif

-- | Run a state-threading action for every pair of indices
--   i from 0 to n-1 and j from 0 to m-1.
--
--   The first index i varies fastest: all values of i are visited for
--   j = 0, then i is reset to 0 and j is incremented, and so on until j
--   reaches m. If either bound is 0 the action is never run and the
--   incoming state is returned unchanged.
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s)
       -> State# s -> State# s
loop2# n m f = loop' 0# 0#
  where
    loop' i j s
      -- outer index exhausted: done
      | isTrue# (j ==# m) = s
      -- inner index exhausted: restart it and advance the outer index
      | isTrue# (i ==# n) = loop' 0# (j +# 1#) s
      | otherwise         = case f i j s of s1 -> loop' (i +# 1#) j s1
{-# INLINE loop2# #-}