module Crypto.Lol.Cyclotomic.Tensor.RepaTensor.RTCommon
( module R
, module Data.Array.Repa.Eval
, module Data.Array.Repa.Repr.Unboxed
, Arr(..), repl, replM, eval, evalM, fTensor, ppTensor
, Trans(Id), trans, dim, (.*), (@*), force
, mulMat, mulDiag
, scalarPow'
, sumS, sumAllS
) where
import Crypto.Lol.Prelude as LP hiding ((!!))
import Algebra.Additive as Additive (C)
import Algebra.Ring as Ring (C)
import Algebra.ZeroTestable as ZeroTestable (C)
import Control.DeepSeq (NFData (..))
import Control.Monad.Identity ()
import Control.Monad.Random
import Data.Array.Repa as R hiding (sumAllP, sumAllS, sumP,
sumS, (*^), (+^), (-^), (/^))
import Data.Array.Repa.Eval hiding (one, zero)
import Data.Array.Repa.Repr.Unboxed
import Data.Coerce
import Data.Singletons
import Data.Singletons.Prelude hiding ((:.))
import Data.Vector.Unboxed as U (replicate, replicateM)
import Test.QuickCheck
-- | A vector of @r@ values stored as an unboxed, one-dimensional Repa array,
-- tagged at the type level by the factored index @m@.
--
-- The smart builders 'repl' and 'replM' produce arrays of length
-- @totientFact m@; the raw 'Arr' constructor is exported and does not check
-- the length.
newtype Arr (m :: Factored) r
  = Arr (Array U DIM1 r)
  deriving (Eq, Show, NFData)

-- Both parameters are nominal: 'coerce' may unwrap/rewrap the newtype where
-- the constructor is in scope, but can never change @m@ or @r@.
type role Arr nominal nominal
-- | An array of length @totientFact m@ whose every entry is the given value.
repl :: forall m r . (Fact m, Unbox r) => r -> Arr m r
repl x = Arr (fromUnboxed (Z :. len) (U.replicate len x))
  where
    len = proxy totientFact (Proxy :: Proxy m)
-- | Run the given action @totientFact m@ times (via 'U.replicateM') and
-- collect the results into an 'Arr'.
replM :: forall m r mon . (Fact m, Unbox r, Monad mon)
      => mon r -> mon (Arr m r)
replM gen = fmap wrap (U.replicateM len gen)
  where
    len = proxy totientFact (Proxy :: Proxy m)
    wrap vec = Arr (fromUnboxed (Z :. len) vec)
-- | Entrywise addition and negation; each result is computed eagerly
-- ('force') into an unboxed array.
instance (Fact m, Additive r, Unbox r) => Additive.C (Arr m r) where
  zero = repl zero
  (+) (Arr xs) (Arr ys) = Arr (force (R.zipWith (+) xs ys))
  negate (Arr xs) = Arr (force (R.map negate xs))
-- | Entrywise multiplication; constants ('one', 'fromInteger') are
-- replicated into every entry.
instance (Fact m, Ring r, Unbox r) => Ring.C (Arr m r) where
  one = repl one
  (*) (Arr xs) (Arr ys) = Arr (force (R.zipWith (*) xs ys))
  fromInteger n = repl (fromInteger n)
-- | An 'Arr' is zero when all of its entries are zero.
--
-- The array is folded down to a single "witness" entry that is nonzero
-- whenever any entry is nonzero. The fold is seeded with entry 0, so the
-- array must be non-empty.
instance (ZeroTestable r, Unbox r, Elt r) => ZeroTestable.C (Arr m r) where
  isZero (Arr arr) = isZero witness
    where
      seed = arr R.! (Z :. 0)
      -- keep a nonzero operand if there is one
      keepNonzero x y
        | isZero x  = y
        | otherwise = x
      witness = foldAllS keepNonzero seed arr
-- | Orphan instance: fully evaluates an unboxed 1-D Repa array with
-- 'deepSeqArray'. Presumably needed by the derived 'NFData' for 'Arr';
-- verify before removing.
instance (Unbox r) => NFData (Array U DIM1 r) where
  rnf arr = arr `deepSeqArray` ()
-- | A random 'Arr' draws every entry independently with the element type's
-- 'random'. Ranged generation is not supported.
instance (Unbox r, Random r, Fact m) => Random (Arr m r) where
  random g = runRand (replM (liftRand random)) g
  randomR = error "randomR nonsensical for Arr"
-- | Generates each entry with the element type's 'arbitrary'; shrinking is
-- not attempted.
instance (Arbitrary r, Unbox r, Fact m) => Arbitrary (Arr m r) where
  arbitrary = replM arbitrary
  shrink _ = []
-- | Build a transform for index @m@ by combining, with '(@*)', the transform
-- that @func@ produces for each prime power in the list returned by
-- @sUnF@ for @m@. The empty list yields @'Id' 1@.
--
-- The transform for the head of the list ends up as the rightmost '(@*)'
-- operand. The tail's transform is computed (and its effects run) before
-- @func@ is run for the head.
fTensor :: forall m r mon . (Fact m, Monad mon)
        => (forall pp . (PPow pp) => TaggedT pp mon (Trans r))
        -> TaggedT m mon (Trans r)
fTensor func = tagT (build (sUnF (sing :: SFactored m)))
  where
    build :: Sing (pplist :: [PrimePower]) -> mon (Trans r)
    build SNil = return (Id 1)
    build (SCons spp rest) = do
      restT <- build rest
      headT <- withWitnessT func spp
      return (restT @* headT)
-- | Build a transform for prime power @pp@ from @func@'s transform for its
-- prime @p@. The prime's transform is placed to the right of an identity
-- of dimension @valuePPow pp `div` valuePrime p@ via '(@*)'.
ppTensor :: forall pp r mon . (PPow pp, Monad mon)
         => (forall p . (Prime p) => TaggedT p mon (Trans r))
         -> TaggedT pp mon (Trans r)
ppTensor func = tagT $ case (sing :: SPrimePower pp) of
  spp@(SPP (STuple2 sp _)) ->
    let ppVal = withWitness valuePPow spp
        pVal  = withWitness valuePrime sp
    in fmap (\primeT -> Id (ppVal `div` pVal) @* primeT) (withWitnessT func sp)
-- | A transform of dimension @d@ (the 'Int' field). The function is applied
-- by 'evalC' to the input after 'expose' has reshaped it into a matrix
-- with @d@ columns.
data Tensorable r = Tensorable
  !Int
  !(forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r)

-- | A 'Tensorable' transform with two extra dimensions @l@ and @r@.
-- '(@*)' grows them when tensoring with identities on the left/right.
-- The component's total dimension is @l*d*r@ ('dimC').
type TransC r = (Tensorable r, Int, Int)

-- | A sequence of transform components.
--
-- * @'Id' n@ is the identity of dimension @n@.
-- * @'TSnoc' rest c@ applies @c@ first and then @rest@ (see 'eval').
data Trans r
  = Id !Int
  | TSnoc !(Trans r) !(TransC r)
-- | Total dimension of a component: left multiplicity × inner dimension ×
-- right multiplicity.
dimC :: TransC r -> Int
dimC (Tensorable inner _, left, right) = left * inner * right
-- | Dimension of a transform, read off its outermost constructor.
dim :: Trans r -> Int
dim t = case t of
  Id n         -> n
  TSnoc _ comp -> dimC comp
-- | Wrap a dimension-@d@ function as a single-component 'Trans', with no
-- identity factors on either side.
trans :: Int -> (forall rep . Source rep r => Array rep DIM2 r -> Array D DIM2 r) -> Trans r
trans d f = TSnoc (Id d) component
  where
    component = (Tensorable d f, 1, 1)
-- | Compose two transforms of equal dimension: under 'eval', @f .* g@
-- applies @g@ first and then @f@.
--
-- It works by replacing the 'Id' at the base of @g@ with @f@. Calls 'error'
-- if the dimensions differ.
(.*) :: Trans r -> Trans r -> Trans r
f .* g =
  if dimF == dimG
    then graft f g
    else error $ "(.*): transform dimensions don't match "
                 LP.++ show dimF LP.++ ", " LP.++ show dimG
  where
    dimF = dim f
    dimG = dim g
    graft base (Id _)            = base
    graft base (TSnoc rest comp) = TSnoc (graft base rest) comp
-- | Tensor two transforms; the result has dimension @dim f * dim g@.
--
-- * An identity on one side is absorbed into each component's left/right
--   multiplicity.
-- * Two non-trivial transforms are combined as
--   @(f \@* Id (dim g)) .* (Id (dim f) \@* g)@.
(@*) :: Trans r -> Trans r -> Trans r
f @* g = case (f, g) of
  (Id a, Id b)                  -> Id (a*b)
  (Id a, TSnoc gRest (c, l, r)) -> TSnoc (f @* gRest) (c, a*l, r)
  (TSnoc fRest (c, l, r), Id b) -> TSnoc (fRest @* g) (c, l, r*b)
  _                             -> (f @* Id (dim g)) .* (Id (dim f) @* g)
-- | Apply one component to a flat array:
--
-- 1. reshape with 'expose',
-- 2. run the component's function,
-- 3. flatten back with 'unexpose',
-- 4. materialise the result with 'force'.
--
-- The left multiplicity is not needed here.
evalC :: (Unbox r) => TransC r -> Array U DIM1 r -> Array U DIM1 r
evalC (Tensorable d f, _, r) arr = force (unexpose r (f (expose d r arr)))
-- | Turn a tagged transform into a function on 'Arr'.
--
-- The most recently snoc'd component is applied first. 'coerce' moves
-- between 'Arr' and the underlying 'Array' it wraps.
eval :: (Unbox r) => Tagged m (Trans r) -> Arr m r -> Arr m r
eval tagged = coerce (run (untag tagged))
  where
    run t = case t of
      Id _            -> id
      TSnoc rest comp -> run rest . evalC comp
-- | Monadic variant of 'eval': run the action to obtain the transform, then
-- return the function that applies it.
evalM :: (Unbox r, Monad mon) => TaggedT m mon (Trans r) -> mon (Arr m r -> Arr m r)
evalM t = fmap (\tr -> eval (return tr)) (untagT t)
-- | View a flat array as a matrix with @d@ columns, so that a transform of
-- dimension @d@ sees one @d@-vector per row.
--
-- Write the row index as @i = q*r + s@ with @0 <= s < r@. Entry @(i, j)@
-- of the result is the flat element at index @q*r*d + j*r + s@, i.e. the
-- entries of a row are @r@ apart. 'unexpose' with the same @r@ applies the
-- inverse index map.
--
-- NOTE(review): the input length is assumed to be a multiple of @d*r@;
-- this is not checked here.
expose :: (Source r1 r)
       => Int -> Int -> Array r1 DIM1 r -> Array D DIM2 r
expose !d !r !arr = backpermute (Z :. sz `div` d :. d) toFlat arr
  where
    (Z :. sz) = extent arr
    -- @i - i `mod` r@ (= q*r) is the first row of i's block of r rows;
    -- scaling it by d skips the q*r*d flat entries of the earlier blocks.
    toFlat (Z :. i :. j) =
      let imodr = i `mod` r
      in Z :. (i - imodr) * d + j * r + imodr
-- | Flatten a matrix produced in 'expose' layout back into a flat array of
-- length @rows * cols@.
--
-- Flat index @i = (q*cols + j)*r + s@ (with @0 <= s < r@, @0 <= j < cols@)
-- reads matrix entry @(r*q + s, j)@.
unexpose :: (Source r1 r) => Int -> Array r1 DIM2 r -> Array D DIM1 r
unexpose !r !arr = backpermute (Z :. rows * cols) toMatrix arr
  where
    (Z :. rows :. cols) = extent arr
    toMatrix (Z :. i) =
      let (blk, s) = i `divMod` r
          (q, j)   = blk `divMod` cols
      in Z :. (r * q + s) :. j
-- | Multiply the matrix @m@ (shape @mrows x mcols@) by every row of @v@.
--
-- @v@ has shape @n x vrows@; each row is treated as a vector of length
-- @vrows@. Entry @(k, i)@ of the @n x mrows@ result is the sum over @j@ of
-- @m[i,j] * v[k,j]@.
--
-- The result is delayed. Calls 'error' (reporting both sizes) when
-- @mcols /= vrows@.
mulMat :: (Source r1 r, Source r2 r, Ring r, Unbox r, Elt r)
       => Array r1 DIM2 r -> Array r2 DIM2 r -> Array D DIM2 r
mulMat !m !v
  | mcols == vrows = fromFunction (sh :. mrows) f
  | otherwise = error $ "mulMat: matrix has " LP.++ show mcols
                  LP.++ " columns but vectors have length " LP.++ show vrows
  where
    (Z :. mrows :. mcols) = extent m
    (sh :. vrows) = extent v
    -- dot product of row i of m with row sh' of v
    f (sh' :. i) = sumAllS $ R.zipWith (*) (slice m (Z :. i :. All)) $ slice v (sh' :. All)
-- | Scale column @i@ of @arr@ (its last index) by @diag ! i@. Each entry is
-- multiplied on the right by the diagonal element, and the result is
-- delayed.
mulDiag :: (Source r1 r, Source r2 r, Ring r)
        => Array r1 DIM1 r -> Array r2 DIM2 r -> Array D DIM2 r
mulDiag !diag !arr =
  fromFunction (extent arr) $ \ix@(_ :. col) ->
    let weight = diag ! (Z :. col)
    in (arr ! ix) * weight
-- | An array of length @totientFact m@ holding the given scalar at index 0
-- and 'LP.zero' everywhere else. The scalar is forced to WHNF when the
-- result is demanded.
scalarPow' :: forall m r . (Fact m, Additive r, Unbox r) => r -> Arr m r
scalarPow' = coerce . build (proxy totientFact (Proxy :: Proxy m))
  where
    build len !c = force $ fromFunction (Z :. len) $ \(Z :. i) ->
      if i == 0 then c else LP.zero
-- | Materialise a delayed array into an unboxed one, sequentially.
force :: (Shape sh, Unbox r) => Array D sh r -> Array U sh r
force delayed = computeS delayed
-- | Sequentially sum along the innermost dimension. Uses the numeric
-- prelude's 'Additive' instead of Repa's built-in 'Num'-based sum.
sumS :: (Source r a, Elt a, Unbox a, Additive a, Shape sh)
     => Array r (sh :. Int) a
     -> Array U sh a
sumS arr = foldS (+) LP.zero arr
-- | Sequentially sum every element of an array, using 'Additive' rather
-- than Repa's 'Num'-based 'R.sumAllS'.
sumAllS :: (Shape sh, Source r a, Elt a, Unbox a, Additive a)
        => Array r sh a
        -> a
sumAllS arr = foldAllS (+) LP.zero arr