module Crypto.Lol.Cyclotomic.Tensor.RepaTensor.RTCommon
( module R
, module Data.Array.Repa.Eval
, module Data.Array.Repa.Repr.Unboxed
, Arr(..), repl, replM, eval, evalM, fTensor, ppTensor
, Trans(Id), trans, dim, (.*), (@*), force
, mulMat, mulDiag
, scalarPow'
, sumS, sumAllS
) where
import Crypto.Lol.LatticePrelude as LP hiding ((!!))
import Control.DeepSeq (NFData (..))
import Control.Monad.Identity
import Control.Monad.Random
import Data.Array.Repa as R hiding (sumAllP, sumAllS, sumP,
sumS, (*^), (+^), (-^), (/^))
import Data.Array.Repa.Eval hiding (one, zero)
import Data.Array.Repa.Repr.Unboxed
import Data.Coerce
import Data.Singletons
import Data.Singletons.Prelude hiding ((:.))
import Data.Type.Natural hiding (Z)
import Data.Typeable
import qualified Data.Vector.Unboxed as U
import Test.QuickCheck
-- | A one-dimensional unboxed Repa array carrying a phantom index @m@
-- of kind 'Factored'.
newtype Arr (m :: Factored) r
  = Arr (Array U DIM1 r)
  deriving (Eq, Show, Typeable, NFData)

-- Both parameters are nominal, so 'Arr' values cannot be coerced to a
-- different index or element type.
type role Arr nominal nominal
-- | Build an 'Arr' in which every entry is the given value. The length is
-- the value of 'totientFact' at index @m@.
repl :: forall m r . (Fact m, Unbox r) => r -> Arr m r
repl = Arr . fromUnboxed (Z :. len) . U.replicate len
  where
    -- number of entries, read off the type-level index
    len = proxy totientFact (Proxy :: Proxy m)
-- | Monadic counterpart of 'repl': runs the given action once per entry
-- (via 'U.replicateM') and collects the results into an 'Arr'.
replM :: forall m r mon . (Fact m, Unbox r, Monad mon)
  => mon r -> mon (Arr m r)
replM = liftM wrap . U.replicateM len
  where
    -- number of entries, read off the type-level index
    len = proxy totientFact (Proxy :: Proxy m)
    -- package the generated vector as a length-@len@ array
    wrap = Arr . fromUnboxed (Z :. len)
-- | Normal-form evaluation of an unboxed one-dimensional array, delegated
-- to Repa's 'deepSeqArray'.
instance (Unbox r) => NFData (Array U DIM1 r) where
  rnf arr = arr `deepSeqArray` ()
-- | Random arrays: every entry is drawn with the element type's 'random'.
-- 'randomR' is deliberately unsupported and calls 'error'.
instance (Unbox r, Random r, Fact m) => Random (Arr m r) where
  random = runRand (replM (liftRand random))

  randomR = error "randomR nonsensical for Arr"
-- | QuickCheck generator: every entry is generated independently with the
-- element type's 'arbitrary'. No shrinking is performed.
instance (Arbitrary r, Unbox r, Fact m) => Arbitrary (Arr m r) where
  arbitrary = replM arbitrary

  shrink = shrinkNothing
-- | Lift a family of prime-power-indexed transforms to a transform indexed
-- by @m@: walks the singleton list obtained from 'sUnF' and combines the
-- per-prime-power transforms with '@*', starting from @Id 1@ for the
-- empty list.
fTensor :: forall m r mon . (Fact m, Monad mon, Unbox r)
  => (forall pp . (PPow pp) => TaggedT pp mon (Trans r))
  -> TaggedT m mon (Trans r)
fTensor func = tagT $ go $ sUnF (sing :: SFactored m)
  where
    go :: Sing (pplist :: [PrimePower]) -> mon (Trans r)
    go SNil = return (Id 1)
    go (SCons spp rest) = do
      -- the tail is processed first, then the head prime power
      tailT <- go rest
      headT <- withWitnessT func spp
      return (tailT @* headT)
-- | Lift a family of prime-indexed transforms to a prime-power-indexed
-- transform. The result is @Id lts \@* f@, where @f@ is the transform at
-- the prime and @lts@ is 'valuePPow' of the same prime with the exponent
-- decreased by one.
--
-- Only exponents of the form @SS _@ are matched by the case expression.
ppTensor :: forall pp r mon . (PPow pp, Monad mon)
  => (forall p . (NatC p) => TaggedT p mon (Trans r))
  -> TaggedT pp mon (Trans r)
ppTensor func = tagT $ case (sing :: SPrimePower pp) of
  SPP (STuple2 sp (SS sePred)) ->
    let lts = withWitness valuePPow (SPP (STuple2 sp sePred))
    in liftM (Id lts @*) (withWitnessT func sp)
-- | A transform on the innermost dimension of a two-dimensional array,
-- stored together with an 'Int' (used as the dimension @d@ by 'dimC'
-- and 'evalC').
data Tensorable r =
  Tensorable
    Int
    (forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r)

-- | A 'Tensorable' with two extra 'Int' factors (named @l@ and @r@ where
-- the triple is pattern-matched).
type TransC r = (Tensorable r, Int, Int)

-- | A sequence of transform components: either an identity of a given
-- dimension, or a shorter sequence followed by one more component.
data Trans r
  = Id Int
  | TSnoc (Trans r) (TransC r)
-- | Dimension of a single component: the product of its left factor, its
-- stored dimension, and its right factor.
dimC :: TransC r -> Int
dimC (Tensorable inner _, left, right) = left * inner * right
-- | Dimension of a transform: the stored size for 'Id', otherwise the
-- dimension of the last component.
dim :: Trans r -> Int
dim t = case t of
  Id n -> n
  TSnoc _ c -> dimC c
-- | Wrap a function on two-dimensional arrays as a single-component
-- 'Trans' of dimension @d@, with both side factors equal to 1.
trans :: Int -> (forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r) -> Trans r
trans d f = Id d `TSnoc` (Tensorable d f, 1, 1)
-- | Compose two transforms of equal dimension by appending the components
-- of the right operand after those of the left operand. Calls 'error'
-- when the dimensions differ.
(.*) :: Trans r -> Trans r -> Trans r
f .* g
  | dim f /= dim g = error $ "(.*): transform dimensions don't match "
      LP.++ show (dim f) LP.++ ", " LP.++ show (dim g)
  | otherwise = compose f g
  where
    -- an 'Id' on the right contributes no components
    compose acc (Id _) = acc
    compose acc (TSnoc rest c) = TSnoc (compose acc rest) c
-- | Combine two transforms so that dimensions multiply (presumably the
-- tensor product of the transforms -- confirm against callers).
--
-- * Two identities merge into one identity of the product dimension.
-- * An identity on the left scales the left factor of every component.
-- * An identity on the right scales the right factor of every component.
-- * Otherwise each operand is padded with an identity of the other's
--   dimension and the two results are composed with '.*'.
(@*) :: Trans r -> Trans r -> Trans r
lhs @* rhs = case (lhs, rhs) of
  (Id n, Id m) -> Id (n * m)
  (Id n, TSnoc g' (g, l, r)) -> TSnoc (lhs @* g') (g, n * l, r)
  (TSnoc f' (f, l, r), Id n) -> TSnoc (f' @* rhs) (f, l, r * n)
  _ -> (lhs @* Id (dim rhs)) .* (Id (dim lhs) @* rhs)
-- | Apply a single component to a one-dimensional array: reshape with
-- 'expose', run the stored function, flatten with 'unexpose', and
-- materialise the result. The input array is fully evaluated first.
-- The left factor of the component is not consulted.
evalC :: (Unbox r) => TransC r -> Array U DIM1 r -> Array U DIM1 r
evalC (Tensorable d f, _, r) arr =
  arr `deepSeqArray` force (unexpose r (f (expose d r arr)))
-- | Turn a tagged transform into a function on 'Arr' values. Components
-- are applied last-to-first: the final 'TSnoc' component acts on the input
-- before the components preceding it.
eval :: (Unbox r) => Tagged m (Trans r) -> Arr m r -> Arr m r
eval = coerce . run . untag
  where
    -- an identity leaves the underlying array untouched
    run (Id _) = id
    run (TSnoc rest c) = run rest . evalC c
-- | Monadic version of 'eval': runs the action producing the transform and
-- returns the corresponding function on 'Arr' values.
evalM :: (Unbox r, Monad mon) => TaggedT m mon (Trans r) -> mon (Arr m r -> Arr m r)
evalM tagged = liftM (\t -> eval (return t)) (untagT tagged)
-- | Reshape the innermost dimension (of size @sz@) into two dimensions of
-- sizes @sz `div` d@ and @d@, as a delayed array.
--
-- Entry @(i, j)@ of the result reads the source at index
-- @(i - i `mod` r) * d + j * r + i `mod` r@, which is the mapping that
-- 'unexpose' inverts.
--
-- Arguments: @d@ is the size of the new innermost dimension, @r@ is the
-- stride between consecutive elements of one row in the source array.
expose :: (Shape sh, Source r1 e)
  => Int -> Int -> Array r1 (sh :. Int) e -> Array D (sh :. Int :. Int) e
expose !d !r !arr =
  let (sh :. sz) = extent arr
      f (s :. i :. j) =
        let imodr = i `mod` r
            -- (i - imodr) is i rounded down to a multiple of r
            idx = (i - imodr) * d + j * r + imodr
        in arr ! (s :. idx)
  in fromFunction (sh :. sz `div` d :. d) f
-- | Flatten the two innermost dimensions (sizes @sz@ and @d@) into a
-- single dimension of size @sz * d@, as a delayed array.
--
-- Flat index @i@ is split as @i = (block * d + col) * r + inR@ and reads
-- the source at row @r * block + inR@, column @col@.
unexpose !r !arr = fromFunction (sh :. sz * d) lookupFlat
  where
    (sh :. sz :. d) = extent arr
    lookupFlat (s :. i) =
      let (outer, inR) = i `divMod` r
          (block, col) = outer `divMod` d
      in arr ! (s :. r * block + inR :. col)
-- | Multiply every innermost row of @v@ by the matrix @m@: entry
-- @(k, i)@ of the result is the sum of the element-wise product of row @i@
-- of @m@ with row @k@ of @v@. The result is a delayed array.
--
-- Calls 'error' when the column count of @m@ differs from the innermost
-- extent of @v@.
mulMat :: (Source r1 r, Source r2 r, Ring r, Unbox r, Elt r)
  => Array r1 DIM2 r -> Array r2 DIM2 r -> Array D DIM2 r
mulMat !m !v
  | mcols == vrows = fromFunction (sh :. mrows) entry
  | otherwise = error "mulMatVec: mcols != vdim"
  where
    (Z :. mrows :. mcols) = extent m
    (sh :. vrows) = extent v
    -- inner product of row i of the matrix with the selected row of v
    entry (outer :. i) =
      sumAllS (R.zipWith (*) (slice m (Z :. i :. All)) (slice v (outer :. All)))
-- | Scale each entry of @arr@ by the element of @diag@ selected by the
-- entry's innermost index. The result is a delayed array with the same
-- extent as @arr@.
mulDiag :: (Source r1 r, Source r2 r, Ring r, Unbox r, Elt r)
  => Array r1 DIM1 r -> Array r2 DIM2 r -> Array D DIM2 r
mulDiag !diag !arr =
  let scale idx@(_ :. col) = (arr ! idx) * (diag ! (Z :. col))
  in fromFunction (extent arr) scale
-- | Build an 'Arr' whose entry at index 0 is the given value and whose
-- remaining entries are 'LP.zero'. The length is the value of
-- 'totientFact' at index @m@.
scalarPow' :: forall m r . (Fact m, Additive r, Unbox r) => r -> Arr m r
scalarPow' = coerce . embed (proxy totientFact (Proxy :: Proxy m))
  where
    embed n !x = force (fromFunction (Z :. n) pick)
      where
        pick (Z :. 0) = x
        pick _ = LP.zero
-- | Materialise a delayed array as an unboxed array by running
-- 'computeP' in the 'Identity' monad.
force :: (Shape sh, Unbox r) => Array D sh r -> Array U sh r
force delayed = runIdentity (computeP delayed)
-- | Sum along the innermost dimension using the prelude's '+' and
-- 'LP.zero' (a sequential fold via 'foldS').
sumS :: (Source r a, Elt a, Unbox a, Additive a, Shape sh)
  => Array r (sh :. Int) a
  -> Array U sh a
sumS arr = foldS (+) LP.zero arr
-- | Sum every element of an array using the prelude's '+' and 'LP.zero'
-- (a sequential fold via 'foldAllS').
sumAllS :: (Shape sh, Source r a, Elt a, Unbox a, Additive a)
  => Array r sh a
  -> a
sumAllS arr = foldAllS (+) LP.zero arr