{-# LANGUAGE BangPatterns, ConstraintKinds, DataKinds, FlexibleContexts, FlexibleInstances, GADTs, GeneralizedNewtypeDeriving, KindSignatures, MultiParamTypeClasses, NoImplicitPrelude, RankNTypes, RebindableSyntax, RoleAnnotations, ScopedTypeVariables, TypeOperators #-} -- | A simple DSL for tensoring Repa arrays and other common functionality -- on Repa arrays module Crypto.Lol.Cyclotomic.Tensor.RepaTensor.RTCommon ( module R , module Data.Array.Repa.Eval , module Data.Array.Repa.Repr.Unboxed , Arr(..), repl, replM, eval, evalM, fTensor, ppTensor , Trans(Id), trans, dim, (.*), (@*), force , mulMat, mulDiag , scalarPow' , sumS, sumAllS ) where import Crypto.Lol.LatticePrelude as LP hiding ((!!)) import Control.DeepSeq (NFData (..)) import Control.Monad.Identity import Control.Monad.Random import Data.Array.Repa as R hiding (sumAllP, sumAllS, sumP, sumS, (*^), (+^), (-^), (/^)) import Data.Array.Repa.Eval hiding (one, zero) import Data.Array.Repa.Repr.Unboxed import Data.Coerce import Data.Singletons import Data.Singletons.Prelude hiding ((:.)) import Data.Type.Natural hiding (Z) import Data.Typeable import qualified Data.Vector.Unboxed as U import Test.QuickCheck -- always unboxed (manifest); intermediate calculations can use -- delayed arrays -- | Indexed newtype for 1-dimensional Unbox repa arrays newtype Arr (m :: Factored) r = Arr (Array U DIM1 r) deriving (Eq, Show, Typeable, NFData) -- the first argument, though phantom, affects representation -- CJP: why must the second arg be nominal? -- EAC: From https://ghc.haskell.org/trac/ghc/wiki/Roles#Thesolution: -- "The exception to the above algorithm is for classes: all parameters for a class default to a nominal role." -- Arr is a synonym for Array, which is an associated data type to the class Source. The parameter `r` above -- corresponds to the parameter `e` in the definition of class Source, so it's role must be nominal. type role Arr nominal nominal -- | An 'Arr' filled with the argument. repl :: forall m r . (Fact m, Unbox r) => r -> Arr m r repl = let n = proxy totientFact (Proxy::Proxy m) in Arr . fromUnboxed (Z:.n) . U.replicate n -- | Monadic version of 'repl'. replM :: forall m r mon . (Fact m, Unbox r, Monad mon) => mon r -> mon (Arr m r) replM = let n = proxy totientFact (Proxy::Proxy m) in liftM (Arr . fromUnboxed (Z:.n)) . U.replicateM n instance (Unbox r) => NFData (Array U DIM1 r) where -- EAC: Repa doesn't define any NFData instances, -- I'm hoping deepSeqArray is a reasonable approx rnf x = deepSeqArray x () instance (Unbox r, Random r, Fact m) => Random (Arr m r) where random = runRand $ replM (liftRand random) randomR = error "randomR nonsensical for Arr" instance (Arbitrary r, Unbox r, Fact m) => Arbitrary (Arr m r) where arbitrary = replM arbitrary shrink = shrinkNothing -- | For a factored index, tensors up any function defined for (and -- tagged by) any prime power fTensor :: forall m r mon . (Fact m, Monad mon, Unbox r) => (forall pp . (PPow pp) => TaggedT pp mon (Trans r)) -> TaggedT m mon (Trans r) fTensor func = tagT $ go $ sUnF (sing :: SFactored m) where go :: Sing (pplist :: [PrimePower]) -> mon (Trans r) go spps = case spps of SNil -> return $ Id 1 (SCons spp rest) -> do rest' <- go rest func' <- withWitnessT func spp return $ rest' @* func' -- | For a prime power pp > 1, tensors up any function f defined for -- (and tagged by) a prime to (I_(pp/p) \otimes f) ppTensor :: forall pp r mon . (PPow pp, Monad mon) => (forall p . (NatC p) => TaggedT p mon (Trans r)) -> TaggedT pp mon (Trans r) ppTensor func = tagT $ case (sing :: SPrimePower pp) of -- intentionally no match for zero exponents, because that is -- ill-formed and indicates an internal error (SPP (STuple2 sp (SS se1))) -> do func' <- withWitnessT func sp let lts = withWitness valuePPow (SPP (STuple2 sp se1)) return $ Id lts @* func' -- deeply embedded DSL for transformations and their various -- compositions -- (dim(f), f) where f operates on innermost dimension of array data Tensorable r = Tensorable Int (forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r) -- transform component: a Tensorable with particular I_l, I_r type TransC r = (Tensorable r, Int, Int) -- full transform: sequence of zero or more components -- | a DSL for tensor transforms on Repa arrays data Trans r = Id Int -- ^| identity sentinel | TSnoc (Trans r) (TransC r) -- ^| (function) composition of transforms dimC :: TransC r -> Int dimC (Tensorable d _, l, r) = l*d*r -- | Returns the (linear) dimension of a transform dim :: Trans r -> Int dim (Id n) = n dim (TSnoc _ f) = dimC f -- just use dimension of head -- | smart constructor from a Tensorable trans :: Int -> (forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r) -> Trans r trans d f = TSnoc (Id d) (Tensorable d f, 1, 1) -- | compose transforms (.*) :: Trans r -> Trans r -> Trans r f .* g | dim f == dim g = f ..* g | otherwise = error $ "(.*): transform dimensions don't match " LP.++ show (dim f) LP.++ ", " LP.++ show (dim g) where f' ..* (Id _) = f' -- drop sentinel f' ..* (TSnoc rest g') = TSnoc (f' ..* rest) g' -- | tensor/Kronecker product (otimes) (@*) :: Trans r -> Trans r -> Trans r -- merge identity transforms (Id n) @* (Id m) = Id (n*m) -- Id on left or right i@(Id n) @* (TSnoc g' (g, l, r)) = TSnoc (i @* g') (g, n*l, r) (TSnoc f' (f, l, r)) @* i@(Id n) = TSnoc (f' @* i) (f, l, r*n) -- no Ids: compose f @* g = (f @* Id (dim g)) .* (Id (dim f) @* g) evalC :: (Unbox r) => TransC r -> Array U DIM1 r -> Array U DIM1 r evalC (Tensorable d f, _, r) arr = arr `deepSeqArray` force $ unexpose r $ f $ expose d r arr -- | Creates an evaluatable Haskell function from a tensored transform eval :: (Unbox r) => Tagged m (Trans r) -> Arr m r -> Arr m r eval x = coerce $ eval' $ untag x where eval' (Id _) = id eval' (TSnoc rest f) = eval' rest . evalC f -- | Monadic version of 'eval' evalM :: (Unbox r, Monad mon) => TaggedT m mon (Trans r) -> mon (Arr m r -> Arr m r) evalM = liftM (eval . return) . untagT -- | maps the innermost dimension to a 2-dim array with innermost dim d, -- for performing a I_l \otimes f_d \otimes I_r transformation expose !d !r !arr = let (sh :. sz) = extent arr f (s :. i :. j) = let imodr = i `mod` r idx = (i-imodr)*d + j*r + imodr in arr ! (s :. idx) in fromFunction (sh :. sz `div` d :. d) f -- | inverse of expose unexpose !r !arr = let (sh:.sz:.d) = extent arr f (s :. i) = let (idivr,imodr) = i `divMod` r (idivrd,j) = idivr `divMod` d in arr ! (s :. r*idivrd + imodr :. j) in fromFunction (sh :. sz*d) f -- | general matrix multiplication along innermost dim of v mulMat :: (Source r1 r, Source r2 r, Ring r, Unbox r, Elt r) => Array r1 DIM2 r -> Array r2 DIM2 r -> Array D DIM2 r mulMat !m !v = let (Z :. mrows :. mcols) = extent m (sh :. vrows) = extent v f (sh' :. i) = sumAllS $ R.zipWith (*) (slice m (Z:.i:.All)) $ slice v (sh':.All) in if mcols == vrows then fromFunction (sh :. mrows) f else error "mulMatVec: mcols != vdim" -- | multiplication by a diagonal matrix along innermost dim mulDiag :: (Source r1 r, Source r2 r, Ring r, Unbox r, Elt r) => Array r1 DIM1 r -> Array r2 DIM2 r -> Array D DIM2 r mulDiag !diag !arr = fromFunction (extent arr) f where f idx@(_ :. i) = arr! idx * diag! (Z:.i) -- misc Tensor functions -- | Embeds a scalar into a powerful-basis representation of a Repa array, -- tagged by the cyclotomic index scalarPow' :: forall m r . (Fact m, Additive r, Unbox r) => r -> Arr m r scalarPow' = coerce . (go $ proxy totientFact (Proxy::Proxy m)) where go n !r = let fct (Z:.0) = r fct _ = LP.zero in force $ fromFunction (Z:.n) fct -- | Forces a delayed array to a manifest array. force :: (Shape sh, Unbox r) => Array D sh r -> Array U sh r -- CJP: computeS just until we figure out how to avoid nested parallel -- computation! --force = computeS force = runIdentity . computeP -- copied implementations of functions we need that normally require -- Num -- | Sum the inner-most dimension of an array sequentially sumS :: (Source r a, Elt a, Unbox a, Additive a, Shape sh) => Array r (sh :. Int) a -> Array U sh a sumS = foldS (+) LP.zero -- | Sum all array indices to a scalar sequentially sumAllS :: (Shape sh, Source r a, Elt a, Unbox a, Additive a) => Array r sh a -> a sumAllS = foldAllS (+) LP.zero