{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Crypto.Lol.Cyclotomic.Tensor.CPP (CT) where
import Algebra.Additive as Additive (C)
import Algebra.Module as Module (C)
import Algebra.Ring as Ring (C)
import Algebra.ZeroTestable as ZeroTestable (C)
import Control.Applicative hiding ((*>))
import Control.DeepSeq
import Control.Monad.Except
import Control.Monad.Identity (Identity (..), runIdentity)
import Control.Monad.Random
import Control.Monad.Trans as T (lift)
import Data.Coerce
import Data.Int
import Data.Maybe
import Data.Traversable as T
import Data.Vector.Generic as V (fromList, toList)
import Data.Vector.Storable as SV (Vector, and, convert, foldl',
fromList, generate, length, map,
replicate, replicateM, thaw,
thaw, toList, unsafeFreeze,
unsafeWith, zipWith, (!))
import Data.Vector.Storable.Mutable as SM hiding (replicate)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr
import Crypto.Lol.CRTrans
import Crypto.Lol.Cyclotomic.Tensor
import Crypto.Lol.Cyclotomic.Tensor.CPP.Backend
import Crypto.Lol.Cyclotomic.Tensor.CPP.Extension
import Crypto.Lol.Cyclotomic.Tensor.CPP.Instances ()
import Crypto.Lol.GaussRandom
import Crypto.Lol.Prelude as LP hiding (replicate,
unzip, zip)
import Crypto.Lol.Reflects
import Crypto.Lol.Tests
import Crypto.Lol.Types.FiniteField
import Crypto.Lol.Types.IFunctor
import Crypto.Lol.Types.IZipVector
import Crypto.Lol.Types.Proto
import Crypto.Lol.Types.Unsafe.RRq
import Crypto.Lol.Types.Unsafe.ZqBasic
import Data.Foldable as F
import System.IO.Unsafe (unsafePerformIO)
-- | Newtype wrapper around a storable 'Vector' of @r@, tagged with a phantom
-- index @m@ of kind 'Factored'. 'Show', 'Eq' and 'NFData' are derived from
-- the underlying vector.
newtype CT' (m :: Factored) r = CT' { unCT' :: Vector r }
deriving (Show, Eq, NFData)
-- The index @m@ is representational, so it can be changed with 'coerce';
-- the element type @r@ is nominal and cannot.
type role CT' representational nominal
-- | The tensor type exported by this module. A value is held in one of two
-- representations, selected by the constructor.
data CT (m :: Factored) r where
-- Storable-vector representation; constructing (and matching on) it
-- carries a 'Storable' dictionary for the element type.
CT :: Storable r => CT' m r -> CT m r
-- 'IZipVector' representation; places no constraint on the element type.
ZV :: IZipVector m r -> CT m r
-- | 'Show' for 'CT' is obtained through standalone deriving.
deriving instance Show r => Show (CT m r)
-- | The argument-type token for 'CT' always renders as the fixed string
-- @"CT"@, independent of the value.
instance Show (ArgType CT) where
  show _argType = "CT"
-- | Equality on 'CT'. Values in the same representation are compared
-- directly. For a mixed pair, the non-'CT' side is converted with 'toCT'
-- first, so both operands end up in storable-vector form.
instance Eq r => Eq (CT m r) where
  {-# SPECIALIZE instance Eq (CT m Int64) #-}
  {-# SPECIALIZE instance Eq (CT m Double) #-}
  {-# SPECIALIZE instance Eq (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance Eq (CT m (ZqBasic q Int64)) #-}
  {-# SPECIALIZE instance Eq (CT m (RRq q Double)) #-}

  ZV a == ZV b = a == b
  CT a == CT b = a == b
  -- Matching on the 'CT' constructor supplies the 'Storable' dictionary
  -- that 'toCT' needs; the instance context itself only has 'Eq'.
  lhs@(CT _) == rhs = lhs == toCT rhs
  lhs == rhs@(CT _) = rhs == toCT lhs
-- | Put a tensor into the storable-vector ('CT') representation. A value
-- that is already a 'CT' is returned unchanged; a 'ZV' is converted with
-- 'zvToCT''.
{-# INLINE toCT #-}
toCT :: (Storable r) => CT m r -> CT m r
toCT = \t -> case t of
  CT _  -> t
  ZV zv -> CT (zvToCT' zv)
-- | Put a tensor into the 'IZipVector' ('ZV') representation. A value that
-- is already a 'ZV' is returned unchanged. For a 'CT', the storable vector
-- is converted and passed to 'iZipVector'; if that yields 'Nothing' the
-- call fails with an internal error.
{-# INLINE toZV #-}
toZV :: (Fact m) => CT m r -> CT m r
toZV = \t -> case t of
  ZV _ -> t
  CT (CT' vec) ->
    ZV (fromMaybe (error "toZV: internal error") (iZipVector (convert vec)))
-- | Unwrap an 'IZipVector', convert its contents to a storable vector, and
-- rewrap the result as a 'CT''.
{-# INLINE zvToCT' #-}
zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r
zvToCT' = \zv -> CT' (convert (unIZipVector zv))
-- | Lift a function on 'CT'' to one on 'CT'. A 'ZV' argument is first
-- converted with 'zvToCT''; the result is always built with the 'CT'
-- constructor.
--
-- The definition takes only @f@ on the left-hand side so that partial
-- applications such as @wrap f@ are still eligible for inlining.
{-# INLINE wrap #-}
wrap :: (Storable s, Storable r) => (CT' l s -> CT' m r) -> (CT l s -> CT m r)
wrap f = \t -> case t of
  ZV zv -> CT (f (zvToCT' zv))
  CT v  -> CT (f v)
-- | Monadic counterpart of 'wrap': lift an effectful function on 'CT'' to
-- one on 'CT'. A 'ZV' argument is converted with 'zvToCT'' before @f@ is
-- applied, and the result is rewrapped with the 'CT' constructor inside
-- the monad.
{-# INLINE wrapM #-}
wrapM :: (Storable s, Storable r, Monad mon) => (CT' l s -> mon (CT' m r))
      -> (CT l s -> mon (CT m r))
wrapM f = \t -> case t of
  ZV zv -> fmap CT (f (zvToCT' zv))
  CT v  -> fmap CT (f v)
-- | Closed type family mapping the type of a function from index @m'@ down
-- to index @m@ on 'CT'' (optionally wrapped in 'Maybe') to the type of a
-- raw 'Vector' function tagged with the index pair @'(m,m')@.
-- NOTE(review): the name suggests "twace"; no use of 'Tw' is visible in
-- this part of the file.
type family Tw (r :: *) :: * where
Tw (CT' m' r -> CT' m r) = Tagged '(m,m') (Vector r -> Vector r)
Tw (Maybe (CT' m' r -> CT' m r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)
-- | Closed type family mapping the type of a function from index @m@ up to
-- index @m'@ on 'CT'' (optionally wrapped in 'Maybe') to the type of a raw
-- 'Vector' function tagged with the index pair @'(m,m')@.
-- NOTE(review): the name suggests "embed"; no use of 'Em' is visible in
-- this part of the file.
type family Em r where
Em (CT' m r -> CT' m' r) = Tagged '(m,m') (Vector r -> Vector r)
Em (Maybe (CT' m r -> CT' m' r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)
-- | Element-wise additive group structure. Operations work on the
-- storable-vector representation; any 'ZV' operand is converted with
-- 'toCT' first. 'zero' is the all-zero vector produced by 'repl'.
instance (Additive r, Storable r, Fact m)
  => Additive.C (CT m r) where
  {-# SPECIALIZE instance Fact m => Additive.C (CT m Int64) #-}
  {-# SPECIALIZE instance Fact m => Additive.C (CT m Double) #-}
  {-# SPECIALIZE instance Fact m => Additive.C (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance (Fact m, Reflects q Int64) => Additive.C (CT m (ZqBasic q Int64)) #-}
  {-# SPECIALIZE instance Fact m => Additive.C (CT m (RRq q Double)) #-}

  zero = CT (repl zero)

  negate (CT (CT' xs)) = CT (CT' (SV.map negate xs))
  negate other         = negate (toCT other)

  CT (CT' xs) + CT (CT' ys) = CT (CT' (SV.zipWith (+) xs ys))
  lhs + rhs                 = toCT lhs + toCT rhs
-- | A tensor is zero when every entry is zero. The 'ZV' case defers to the
-- 'IZipVector' instance; the 'CT' case folds 'isZero' over the vector with
-- a strict left fold.
instance (ZeroTestable r, Storable r) => ZeroTestable.C (CT m r) where
  isZero (ZV zv)        = isZero zv
  isZero (CT (CT' vec)) = SV.foldl' allZero True vec
    where
      allZero acc x = acc && isZero x
-- | Scalar multiplication of a tensor over @fp@ by an element of the
-- extension field @GF fp d@. The entries are read out as a list, wrapped
-- in 'Coeffs', scaled, and written back in the same representation the
-- argument arrived in.
instance (GFCtx fp d, Fact m, Additive (CT m fp))
  => Module.C (GF fp d) (CT m fp) where
  r *> v = case v of
      CT (CT' arr) -> CT $ CT' $ SV.fromList $ scale $ SV.toList arr
      ZV zv ->
        -- 'iZipVector' returns a 'Maybe'. Fail with an identifiable
        -- message (as 'toZV' does) rather than through a bare 'fromJust'.
        ZV $ fromMaybe (error "Module (GF fp d) (CT m fp): internal error: iZipVector failed")
           $ iZipVector $ V.fromList $ scale $ V.toList $ unIZipVector zv
    where
      -- Scale a coefficient list by @r@. This is the same expression both
      -- branches of the original used, factored out once.
      scale xs = unCoeffs $ r *> Coeffs $ xs
-- | Element-wise ring structure: multiplication zips the two operands with
-- '(*)' via 'zipWithI', and an integer literal becomes the vector whose
-- every entry is that integer (built with 'repl').
instance (Fact m, Ring r, Storable r) => Ring.C (CT m r) where
  {-# SPECIALIZE instance Fact m => Ring.C (CT m Int64) #-}
  {-# SPECIALIZE instance Fact m => Ring.C (CT m Double) #-}
  {-# SPECIALIZE instance Fact m => Ring.C (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance (Fact m, Reflects q Int64) => Ring.C (CT m (ZqBasic q Int64)) #-}

  fromInteger n = CT (repl (fromInteger n))

  (*) = zipWithI (*)
-- | Scalar multiplication by a base-ring element: every entry is
-- multiplied on the left by the scalar. The vector-level map is coerced to
-- a 'CT'' function and lifted with 'wrap', so the result is always in the
-- 'CT' representation.
instance (Fact m, Ring r, Storable r) => Module.C r (CT m r) where
  {-# SPECIALIZE instance Fact m => Module.C Int64 (CT m Int64) #-}
  {-# SPECIALIZE instance Fact m => Module.C Double (CT m Double) #-}
  {-# SPECIALIZE instance Fact m => Module.C (Complex Double) (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance (Fact m, Reflects q Int64) => Module.C (ZqBasic q Int64) (CT m (ZqBasic q Int64)) #-}

  (*>) r = wrap (coerce (SV.map (r *)))
-- | Approximate equality holds when it holds for every pair of
-- corresponding entries. Operands not in 'CT' form are converted with
-- 'toCT' first.
instance (ApproxEqual r, Storable r) => ApproxEqual (CT m r) where
  CT (CT' xs) =~= CT (CT' ys) = SV.and (SV.zipWith (=~=) xs ys)
  lhs =~= rhs                 = toCT lhs =~= toCT rhs
-- | 'fmap' is defined through the 'Applicative' instance: the function is
-- lifted with 'pure' and applied with '<*>'.
instance Fact m => Functor (CT m) where
  fmap g xs = pure g <*> xs
-- | Applicative structure inherited from 'IZipVector'. 'pure' always
-- produces a 'ZV'. A 'CT' argument on the right is converted with 'toZV'
-- before applying. A 'CT' on the left (which would be a storable vector
-- of functions) is treated as an internal error.
instance Fact m => Applicative (CT m) where
  pure x = ZV (pure x)

  ZV fs     <*> ZV xs     = ZV (fs <*> xs)
  fs@(ZV _) <*> xs@(CT _) = fs <*> toZV xs
  CT _      <*> _         = error "CPP: internal error: function held by CT"
-- | 'foldMap' is the default implementation derived from the
-- 'Traversable' instance.
instance Fact m => Foldable (CT m) where
  foldMap g = foldMapDefault g
-- | Traversal is carried out on the 'IZipVector' representation: a 'ZV' is
-- traversed directly and rewrapped, while a 'CT' is converted with 'toZV'
-- and then traversed.
instance Fact m => Traversable (CT m) where
  traverse g (ZV zv)   = fmap ZV (T.traverse g zv)
  traverse g ct@(CT _) = T.traverse g (toZV ct)
-- | Gaussian sampling over 'Double': the sample produced by
-- 'cDispatchGaussianDouble' is wrapped in the 'CT' constructor.
instance TensorGaussian CT Double where
  tweakedGaussianDec = fmap CT . cDispatchGaussianDouble
-- | The CRT set for 'ZqBasic' elements: the raw vectors from 'crtSetDec''
-- are retagged with 'coerceBasis' and each one is wrapped in 'CT'.
instance (Reflects q Int64, Prime q, IrreduciblePoly (ZqBasic q Int64))
  => TensorCRTSet CT (ZqBasic q Int64) where
  crtSetDec = fmap (fmap CT) (coerceBasis crtSetDec')
-- | Constrained functor operations for 'CT'; the element constraint is
-- 'Storable'. Both operations work on the storable-vector representation
-- and convert any other operand with 'toCT' before retrying.
instance IFunctor CT where
  type IFElt CT a = Storable a

  fmapI g (CT (CT' xs)) = CT (CT' (SV.map g xs))
  fmapI g other         = fmapI g (toCT other)

  zipWithI g (CT (CT' xs)) (CT (CT' ys)) = CT (CT' (SV.zipWith g xs ys))
  zipWithI g lhs rhs                     = zipWithI g (toCT lhs) (toCT rhs)
-- | 'TensorPowDec' for 'ZqBasic' elements. Each method lifts a
-- 'CT''-level routine to 'CT' with 'wrap' or 'wrapM'. 'decToPow' and
-- 'powToDec' run the imported routines 'dlZq' and 'dlinvZq' through
-- 'basicDispatch'; the remaining methods retag the primed helpers
-- ('twacePowDec'', 'embedPow'', ...) with the @coerce*@ functions.
instance Reflects q Int64 => TensorPowDec CT (ZqBasic q Int64) where
  {-# INLINABLE scalarPow #-}
  scalarPow = CT . scalarPow'

  {-# INLINABLE decToPow #-}
  decToPow = wrap (basicDispatch dlZq)

  {-# INLINABLE powToDec #-}
  powToDec = wrap (basicDispatch dlinvZq)

  {-# INLINABLE twacePowDec #-}
  twacePowDec = wrap (runIdentity (coerceTw twacePowDec'))

  {-# INLINABLE embedPow #-}
  embedPow = wrap (runIdentity (coerceEm embedPow'))

  {-# INLINABLE embedDec #-}
  embedDec = wrap (runIdentity (coerceEm embedDec'))

  {-# INLINABLE coeffs #-}
  coeffs = wrapM (coerceCoeffs coeffs')

  {-# INLINABLE powBasisPow #-}
  powBasisPow = fmap (fmap CT) (coerceBasis powBasisPow')
-- | Multiplication and division by @g@ for 'ZqBasic' elements. The
-- multiplications go through 'basicDispatch'; the divisions go through
-- 'dispatchGInv' and are lifted with 'wrapM' because they return a
-- 'Maybe'.
instance Reflects q Int64 => TensorG CT (ZqBasic q Int64) where
  {-# INLINABLE mulGPow #-}
  mulGPow = wrap (basicDispatch dmulgpowZq)

  {-# INLINABLE mulGDec #-}
  mulGDec = wrap (basicDispatch dmulgdecZq)

  {-# INLINABLE divGPow #-}
  divGPow = wrapM (dispatchGInv dginvpowZq)

  {-# INLINABLE divGDec #-}
  divGDec = wrapM (dispatchGInv dginvdecZq)
-- | CRT functions for 'ZqBasic' elements, in the 'Maybe' monad.
--
-- 'crtFuncs' builds a 5-tuple: a scalar-replication function, two
-- functions that zip the argument with 'gCRT' and 'gInvCRT' respectively
-- using 'dmulZq', and the functions obtained from 'ctCRTZq' and
-- 'ctCRTInvZq'. 'crtExtFuncs' pairs the lifted 'twaceCRT'' and
-- 'embedCRT''.
instance Reflects q Int64 => TensorCRT CT Maybe (ZqBasic q Int64) where
  {-# NOINLINE crtFuncs #-}
  crtFuncs =
    (,,,,)
      <$> pure (CT . repl)
      <*> fmap (wrap . untag (cZipDispatch dmulZq)) gCRT
      <*> fmap (wrap . untag (cZipDispatch dmulZq)) gInvCRT
      <*> fmap wrap (untagT ctCRTZq)
      <*> fmap wrap (untagT ctCRTInvZq)

  {-# NOINLINE crtExtFuncs #-}
  crtExtFuncs =
    (,)
      <$> fmap wrap (coerceTw twaceCRT')
      <*> fmap wrap (coerceEm embedCRT')
-- | 'TensorPowDec' for @Complex Double@ elements. Each method lifts a
-- 'CT''-level routine to 'CT' with 'wrap' or 'wrapM'. 'decToPow' and
-- 'powToDec' run the imported routines 'dlC' and 'dlinvC' through
-- 'basicDispatch'; the remaining methods retag the primed helpers with
-- the @coerce*@ functions.
instance TensorPowDec CT (Complex Double) where
  {-# INLINABLE scalarPow #-}
  scalarPow = CT . scalarPow'

  {-# INLINABLE decToPow #-}
  decToPow = wrap (basicDispatch dlC)

  {-# INLINABLE powToDec #-}
  powToDec = wrap (basicDispatch dlinvC)

  {-# INLINABLE twacePowDec #-}
  twacePowDec = wrap (runIdentity (coerceTw twacePowDec'))

  {-# INLINABLE embedPow #-}
  embedPow = wrap (runIdentity (coerceEm embedPow'))

  {-# INLINABLE embedDec #-}
  embedDec = wrap (runIdentity (coerceEm embedDec'))

  {-# INLINABLE coeffs #-}
  coeffs = wrapM (coerceCoeffs coeffs')

  {-# INLINABLE powBasisPow #-}
  powBasisPow = fmap (fmap CT) (coerceBasis powBasisPow')
-- | Multiplication and division by @g@ for @Complex Double@ elements. The
-- multiplications go through 'basicDispatch'; the divisions go through
-- 'dispatchGInv' and are lifted with 'wrapM' because they return a
-- 'Maybe'.
instance TensorG CT (Complex Double) where
  {-# INLINABLE mulGPow #-}
  mulGPow = wrap (basicDispatch dmulgpowC)

  {-# INLINABLE mulGDec #-}
  mulGDec = wrap (basicDispatch dmulgdecC)

  {-# INLINABLE divGPow #-}
  divGPow = wrapM (dispatchGInv dginvpowC)

  {-# INLINABLE divGDec #-}
  divGDec = wrapM (dispatchGInv dginvdecC)
-- | CRT functions for @Complex Double@ elements, in the 'Identity' monad.
--
-- 'crtFuncs' builds a 5-tuple: a scalar-replication function, two
-- functions that zip the argument with 'gCRT' and 'gInvCRT' respectively
-- using 'dmulC', and the functions obtained from 'ctCRTC' and
-- 'ctCRTInvC'. 'crtExtFuncs' pairs the lifted 'twaceCRT'' and
-- 'embedCRT''.
instance TensorCRT CT Identity (Complex Double) where
  {-# NOINLINE crtFuncs #-}
  crtFuncs =
    (,,,,)
      <$> pure (CT . repl)
      <*> fmap (wrap . untag (cZipDispatch dmulC)) gCRT
      <*> fmap (wrap . untag (cZipDispatch dmulC)) gInvCRT
      <*> fmap wrap (untagT ctCRTC)
      <*> fmap wrap (untagT ctCRTInvC)

  {-# NOINLINE crtExtFuncs #-}
  crtExtFuncs =
    (,)
      <$> fmap wrap (coerceTw twaceCRT')
      <*> fmap wrap (coerceEm embedCRT')
-- | The 'Maybe' version of the @Complex Double@ CRT functions: the results
-- of the 'Identity' instance are unwrapped with 'runIdentity' and returned
-- as 'Just'.
instance TensorCRT CT Maybe (Complex Double) where
  {-# NOINLINE crtFuncs #-}
  crtFuncs = Just (runIdentity crtFuncs)

  {-# NOINLINE crtExtFuncs #-}
  crtExtFuncs = Just (runIdentity crtExtFuncs)
-- | 'TensorPowDec' for 'Double' elements. Each method lifts a 'CT''-level
-- routine to 'CT' with 'wrap' or 'wrapM'. 'decToPow' and 'powToDec' run
-- the imported routines 'dlDouble' and 'dlinvDouble' through
-- 'basicDispatch'; the remaining methods retag the primed helpers with
-- the @coerce*@ functions.
instance TensorPowDec CT Double where
  {-# INLINABLE scalarPow #-}
  scalarPow = CT . scalarPow'

  {-# INLINABLE decToPow #-}
  decToPow = wrap (basicDispatch dlDouble)

  {-# INLINABLE powToDec #-}
  powToDec = wrap (basicDispatch dlinvDouble)

  {-# INLINABLE twacePowDec #-}
  twacePowDec = wrap (runIdentity (coerceTw twacePowDec'))

  {-# INLINABLE embedPow #-}
  embedPow = wrap (runIdentity (coerceEm embedPow'))

  {-# INLINABLE embedDec #-}
  embedDec = wrap (runIdentity (coerceEm embedDec'))

  {-# INLINABLE coeffs #-}
  coeffs = wrapM (coerceCoeffs coeffs')

  {-# INLINABLE powBasisPow #-}
  powBasisPow = fmap (fmap CT) (coerceBasis powBasisPow')
-- | Multiplication and division by @g@ for 'Double' elements. The
-- multiplications go through 'basicDispatch'; the divisions go through
-- 'dispatchGInv' and are lifted with 'wrapM' because they return a
-- 'Maybe'.
instance TensorG CT Double where
  {-# INLINABLE mulGPow #-}
  mulGPow = wrap (basicDispatch dmulgpowDouble)

  {-# INLINABLE mulGDec #-}
  mulGDec = wrap (basicDispatch dmulgdecDouble)

  {-# INLINABLE divGPow #-}
  divGPow = wrapM (dispatchGInv dginvpowDouble)

  {-# INLINABLE divGDec #-}
  divGDec = wrapM (dispatchGInv dginvdecDouble)
-- | No CRT functions are provided for 'Double' elements: both methods are
-- 'Nothing'.
instance TensorCRT CT Maybe Double where
  {-# INLINABLE crtFuncs #-}
  crtFuncs = Nothing

  {-# INLINABLE crtExtFuncs #-}
  crtExtFuncs = Nothing
-- | Norm computation for 'Double' elements via 'gSqNormDecDouble'. A 'ZV'
-- argument is converted with 'zvToCT'' before the routine is applied.
instance TensorGSqNorm CT Double where
  {-# INLINABLE gSqNormDec #-}
  gSqNormDec = \t -> case t of
    CT v  -> untag gSqNormDecDouble v
    ZV zv -> untag gSqNormDecDouble (zvToCT' zv)
-- | 'TensorPowDec' for 'Int64' elements. Each method lifts a 'CT''-level
-- routine to 'CT' with 'wrap' or 'wrapM'. 'decToPow' and 'powToDec' run
-- the imported routines 'dlInt64' and 'dlinvInt64' through
-- 'basicDispatch'; the remaining methods retag the primed helpers with
-- the @coerce*@ functions.
instance TensorPowDec CT Int64 where
  {-# INLINABLE scalarPow #-}
  scalarPow = CT . scalarPow'

  {-# INLINABLE decToPow #-}
  decToPow = wrap (basicDispatch dlInt64)

  {-# INLINABLE powToDec #-}
  powToDec = wrap (basicDispatch dlinvInt64)

  {-# INLINABLE twacePowDec #-}
  twacePowDec = wrap (runIdentity (coerceTw twacePowDec'))

  {-# INLINABLE embedPow #-}
  embedPow = wrap (runIdentity (coerceEm embedPow'))

  {-# INLINABLE embedDec #-}
  embedDec = wrap (runIdentity (coerceEm embedDec'))

  {-# INLINABLE coeffs #-}
  coeffs = wrapM (coerceCoeffs coeffs')

  {-# INLINABLE powBasisPow #-}
  powBasisPow = fmap (fmap CT) (coerceBasis powBasisPow')
-- | Multiplication and division by @g@ for 'Int64' elements. The
-- multiplications go through 'basicDispatch'; the divisions go through
-- 'dispatchGInv' and are lifted with 'wrapM' because they return a
-- 'Maybe'.
instance TensorG CT Int64 where
  {-# INLINABLE mulGPow #-}
  mulGPow = wrap (basicDispatch dmulgpowInt64)

  {-# INLINABLE mulGDec #-}
  mulGDec = wrap (basicDispatch dmulgdecInt64)

  {-# INLINABLE divGPow #-}
  divGPow = wrapM (dispatchGInv dginvpowInt64)

  {-# INLINABLE divGDec #-}
  divGDec = wrapM (dispatchGInv dginvdecInt64)
-- | No CRT functions are provided for 'Int64' elements: both methods are
-- 'Nothing'.
instance TensorCRT CT Maybe Int64 where
  {-# INLINABLE crtFuncs #-}
  crtFuncs = Nothing

  {-# INLINABLE crtExtFuncs #-}
  crtExtFuncs = Nothing
-- | Norm computation for 'Int64' elements via 'gSqNormDecInt64'. A 'ZV'
-- argument is converted with 'zvToCT'' before the routine is applied.
instance TensorGSqNorm CT Int64 where
  {-# INLINABLE gSqNormDec #-}
  gSqNormDec = \t -> case t of
    CT v  -> untag gSqNormDecInt64 v
    ZV zv -> untag gSqNormDecInt64 (zvToCT' zv)
-- | 'TensorPowDec' for @RRq q Double@ elements. Most methods follow the
-- same pattern as the other element types, using the imported routines
-- 'dlRRq' and 'dlinvRRq'. 'powBasisPow' differs: it first builds the basis
-- over @ZqBasic q Int64@ and then maps 'fromSubgroup' over every entry.
instance (Reflects q Int64, Reflects q Double) => TensorPowDec CT (RRq q Double) where
  {-# INLINABLE scalarPow #-}
  scalarPow = CT . scalarPow'

  {-# INLINABLE decToPow #-}
  decToPow = wrap (basicDispatch dlRRq)

  {-# INLINABLE powToDec #-}
  powToDec = wrap (basicDispatch dlinvRRq)

  {-# INLINABLE twacePowDec #-}
  twacePowDec = wrap (runIdentity (coerceTw twacePowDec'))

  {-# INLINABLE embedPow #-}
  embedPow = wrap (runIdentity (coerceEm embedPow'))

  {-# INLINABLE embedDec #-}
  embedDec = wrap (runIdentity (coerceEm embedDec'))

  {-# INLINABLE coeffs #-}
  coeffs = wrapM (coerceCoeffs coeffs')

  {-# INLINABLE powBasisPow #-}
  powBasisPow :: forall m m' . (m `Divides` m') => Tagged m [CT m' (RRq q Double)]
  powBasisPow =
    -- The annotation fixes the intermediate element type; without it the
    -- type of the @ZqBasic@ basis would be ambiguous.
    let zqBasis = fmap (fmap CT) (coerceBasis powBasisPow')
                    :: Tagged m [CT m' (ZqBasic q Int64)]
    in fmap (fmap (fmapI fromSubgroup)) zqBasis
-- | Strip the @'(m, m')@ tag from a vector function and coerce it to a
-- function from index @m'@ down to index @m@ on 'CT''.
coerceTw :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m' r -> CT' m r)
coerceTw tagged = fmap coerce (untagT tagged)
-- | Strip the @'(m, m')@ tag from a vector function and coerce it to a
-- function from index @m@ up to index @m'@ on 'CT''.
coerceEm :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m r -> CT' m' r)
coerceEm tagged = fmap coerce (untagT tagged)
-- | Coerce a tagged vector-to-list-of-vectors function into the
-- corresponding function on 'CT'': from index @m'@ to a list at index
-- @m@. This is a pure representation change.
coerceCoeffs :: Tagged '(m,m') (Vector r -> [Vector r]) -> CT' m' r -> [CT' m r]
coerceCoeffs tagged = coerce tagged
-- | Coerce a list of raw vectors tagged by @'(m,m')@ into a list of
-- 'CT'' values at index @m'@ tagged by @m@. This is a pure representation
-- change.
coerceBasis :: Tagged '(m,m') [Vector r] -> Tagged m [CT' m' r]
coerceBasis tagged = coerce tagged
-- | Run a foreign "divide by g" style routine on a copy of the input.
--
-- The routine receives a pointer to a mutable copy of the data, the
-- totient of @m@, a pointer to the marshalled prime-power factors of @m@,
-- and the number of factors. A nonzero return code yields 'Just' the
-- (frozen) buffer; a zero return code yields 'Nothing'.
--
-- The index-dependent values are bound outside the lambda, so they are
-- not tied to any particular argument.
{-# INLINE dispatchGInv #-}
dispatchGInv :: forall m r . (Storable r, Fact m)
             => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16)
             -> CT' m r -> Maybe (CT' m r)
dispatchGInv =
  let pps      = marshalFactors $ ppsFact @m
      dim      = fromIntegral $ totientFact @m
      nFactors = fromIntegral $ SV.length pps
  in \cfun (CT' input) -> unsafePerformIO $ do
       -- 'SV.thaw' copies, so the foreign call never mutates the input.
       buf <- SV.thaw input
       status <- SM.unsafeWith buf $ \bufPtr ->
         SV.unsafeWith pps $ \ppsPtr ->
           cfun bufPtr dim ppsPtr nFactors
       if status == 0
         then return Nothing
         else Just . CT' <$> unsafeFreeze buf
-- | Run a foreign in-place routine on a copy of the input and return the
-- modified copy.
--
-- The routine receives a pointer to a mutable copy of the data, the
-- totient of @m@, a pointer to the marshalled prime-power factors of @m@,
-- and the number of factors. The input vector itself is never mutated,
-- because 'SV.thaw' makes a fresh copy.
{-# INLINE withBasicArgs #-}
withBasicArgs :: forall m r . (Fact m, Storable r)
              => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
              -> CT' m r -> IO (CT' m r)
withBasicArgs =
  let pps      = marshalFactors $ ppsFact @m
      dim      = fromIntegral $ totientFact @m
      nFactors = fromIntegral $ SV.length pps
  in \cfun (CT' input) -> do
       buf <- SV.thaw input
       SM.unsafeWith buf $ \bufPtr ->
         SV.unsafeWith pps $ \ppsPtr ->
           cfun bufPtr dim ppsPtr nFactors
       -- The buffer is not touched again after this point, so freezing it
       -- without a copy is fine.
       CT' <$> unsafeFreeze buf
-- | Pure wrapper around 'withBasicArgs': runs the foreign routine on a
-- copy of the input and exposes the result as a pure value through
-- 'unsafePerformIO'.
--
-- Only @f@ is bound on the left-hand side so that partial applications
-- @basicDispatch f@ remain eligible for inlining.
{-# INLINE basicDispatch #-}
basicDispatch :: (Storable r, Fact m)
              => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
              -> CT' m r -> CT' m r
basicDispatch f = \x -> unsafePerformIO (withBasicArgs f x)
-- | Run 'dnormDouble' on a copy of the input through 'withBasicArgs' and
-- return the entry at index 0 of the resulting vector.
gSqNormDecDouble :: Fact m => Tagged m (CT' m Double -> Double)
gSqNormDecDouble =
  return $ \v -> unCT' (unsafePerformIO (withBasicArgs dnormDouble v)) ! 0
-- | Run 'dnormInt64' on a copy of the input through 'withBasicArgs' and
-- return the entry at index 0 of the resulting vector.
gSqNormDecInt64 :: Fact m => Tagged m (CT' m Int64 -> Int64)
gSqNormDecInt64 =
  return $ \v -> unCT' (unsafePerformIO (withBasicArgs dnormInt64 v)) ! 0
-- | Build the forward CRT function for 'ZqBasic' elements. The root
-- vectors from 'ru' are marshalled to a pointer array with 'withPtrArray'
-- and handed to 'dcrtZq', which is run through 'withBasicArgs'.
ctCRTZq :: (Fact m, Reflects q Int64, CRTrans mon (ZqBasic q Int64))
        => TaggedT m mon (CT' m (ZqBasic q Int64) -> CT' m (ZqBasic q Int64))
ctCRTZq = do
  roots <- ru
  return $ \x -> unsafePerformIO $
    withPtrArray roots $ \rootsPtr ->
      withBasicArgs (dcrtZq rootsPtr) x
-- | Build the inverse CRT function for 'ZqBasic' elements. The root
-- vectors from 'ruInv' are marshalled to a pointer array, and the second
-- component of 'crtInfo' is passed by pointer. Both go to 'dcrtinvZq',
-- which is run through 'withBasicArgs'.
ctCRTInvZq :: (Fact m, Reflects q Int64, CRTrans mon (ZqBasic q Int64))
           => TaggedT m mon (CT' m (ZqBasic q Int64) -> CT' m (ZqBasic q Int64))
ctCRTInvZq = do
  mhatInv  <- fmap snd crtInfo
  invRoots <- ruInv
  return $ \x -> unsafePerformIO $
    withPtrArray invRoots $ \rootsPtr ->
      with mhatInv $ \mhatPtr ->
        withBasicArgs (dcrtinvZq rootsPtr mhatPtr) x
-- | Build the forward CRT function for @Complex Double@ elements. The
-- root vectors from 'ru' are marshalled to a pointer array with
-- 'withPtrArray' and handed to 'dcrtC', which is run through
-- 'withBasicArgs'.
ctCRTC :: (CRTrans mon (Complex Double), Fact m)
       => TaggedT m mon (CT' m (Complex Double) -> CT' m (Complex Double))
ctCRTC = do
  roots <- ru
  return $ \x -> unsafePerformIO $
    withPtrArray roots $ \rootsPtr ->
      withBasicArgs (dcrtC rootsPtr) x
-- | Build the inverse CRT function for @Complex Double@ elements. The
-- root vectors from 'ruInv' are marshalled to a pointer array, and the
-- second component of 'crtInfo' is passed by pointer. Both go to
-- 'dcrtinvC', which is run through 'withBasicArgs'.
ctCRTInvC :: (CRTrans mon (Complex Double), Fact m)
          => TaggedT m mon (CT' m (Complex Double) -> CT' m (Complex Double))
ctCRTInvC = do
  mhatInv  <- fmap snd crtInfo
  invRoots <- ruInv
  return $ \x -> unsafePerformIO $
    withPtrArray invRoots $ \rootsPtr ->
      with mhatInv $ \mhatPtr ->
        withBasicArgs (dcrtinvC rootsPtr mhatPtr) x
-- | Turn a foreign two-pointer routine into a pure binary function on
-- 'CT''. The first operand is copied into a mutable buffer that the
-- routine receives as its first pointer; the second operand is passed
-- read-only as the second pointer, together with the totient of @m@. The
-- frozen buffer is the result.
cZipDispatch :: forall m r . (Storable r, Fact m)
             => (Ptr r -> Ptr r -> Int64 -> IO ())
             -> Tagged m (CT' m r -> CT' m r -> CT' m r)
cZipDispatch f =
  let dim = fromIntegral $ totientFact @m
      zipped acc arg = unsafePerformIO $ do
        buf <- SV.thaw acc
        SM.unsafeWith buf $ \bufPtr ->
          SV.unsafeWith arg $ \argPtr ->
            f bufPtr argPtr dim
        unsafeFreeze buf
  -- 'zipped' works on raw vectors; 'coerce' retypes it to 'CT''.
  in return (coerce zipped)
-- | Sample a 'CT'' of 'Double's for the given variance parameter.
--
-- Independent real Gaussians are drawn with 'realGaussians', one per
-- totient-of-@m@ entry, using @var@ scaled by @valueFact / radicalFact@.
-- The foreign routine 'dgaussdecDouble' is then run on the sample through
-- 'withBasicArgs', with the root vectors passed as a pointer array.
--
-- NOTE(review): the roots passed here are the forward roots 'ru', whereas
-- the inverse-CRT helpers in this file pass 'ruInv'. Confirm against the
-- backend which set 'dgaussdecDouble' expects.
cDispatchGaussianDouble :: forall m var rnd .
                           (Fact m, ToRational var, MonadRandom rnd)
                        => var -> rnd (CT' m Double)
cDispatchGaussianDouble var = flip proxyT (Proxy :: Proxy m) $ do
  let dim  = totientFact @m
      mval = valueFact @m
      rad  = radicalFact @m
  roots   <- mapTaggedT (return . runIdentity) ru
  samples <- T.lift (realGaussians (var * fromIntegral (mval `div` rad)) dim)
  return $ unsafePerformIO $
    withPtrArray roots $ \rootsPtr ->
      withBasicArgs (dgaussdecDouble rootsPtr) (CT' samples)
-- | Random generation for 'CT''. 'random' draws one element per entry
-- using 'replM'. 'randomR' pairs up the entries of the lower and upper
-- bound vectors and draws each entry within its own range.
instance (Storable r, Random r, Fact m) => Random (CT' m r) where
  {-# SPECIALIZE instance Fact m => Random (CT' m Int64) #-}
  {-# SPECIALIZE instance Fact m => Random (CT' m Double) #-}
  {-# SPECIALIZE instance Fact m => Random (CT' m (Complex Double)) #-}
  {-# SPECIALIZE instance (Fact m, Reflects q Int64) => Random (CT' m (ZqBasic q Int64)) #-}
  {-# SPECIALIZE instance Fact m => Random (CT' m (RRq q Double)) #-}

  {-# INLINABLE random #-}
  random = runRand (replM (liftRand random))

  {-# INLINABLE randomR #-}
  randomR (CT' los, CT' his) = runRand (fmap (CT' . SV.fromList) bounded)
    where
      bounded    = zipWithM pick (SV.toList los) (SV.toList his)
      pick lo hi = liftRand (randomR (lo, hi))
-- | Random generation for 'CT', delegating to the 'CT'' instance and
-- wrapping the result in the 'CT' constructor. For 'randomR', bounds that
-- are not both in 'CT' form are converted with 'toCT' first.
instance (Storable r, Random r, Fact m) => Random (CT m r) where
  {-# SPECIALIZE instance Fact m => Random (CT m Int64) #-}
  {-# SPECIALIZE instance Fact m => Random (CT m Double) #-}
  {-# SPECIALIZE instance Fact m => Random (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance (Fact m, Reflects q Int64) => Random (CT m (ZqBasic q Int64)) #-}
  {-# SPECIALIZE instance Fact m => Random (CT m (RRq q Double)) #-}

  {-# INLINABLE random #-}
  random = runRand (fmap CT (liftRand random))

  {-# INLINABLE randomR #-}
  randomR (CT lo, CT hi) = runRand (fmap CT (liftRand (randomR (lo, hi))))
  randomR (lo, hi)       = randomR (toCT lo, toCT hi)
-- | Deep evaluation of a 'CT' forwards to the 'NFData' instance of
-- whichever representation the value holds.
instance (NFData r) => NFData (CT m r) where
  {-# SPECIALIZE instance NFData (CT m Int64) #-}
  {-# SPECIALIZE instance NFData (CT m Double) #-}
  {-# SPECIALIZE instance NFData (CT m (Complex Double)) #-}
  {-# SPECIALIZE instance NFData (CT m (ZqBasic q Int64)) #-}
  {-# SPECIALIZE instance NFData (CT m (RRq q Double)) #-}

  rnf (ZV zv)  = rnf zv
  rnf (CT vec) = rnf vec
-- | Protocol-buffer conversion for 'CT', reusing the 'IZipVector'
-- instance. Serialising converts a 'CT' value with 'toZV' first;
-- deserialising reads an 'IZipVector' and converts the result to the 'CT'
-- representation with 'toCT'.
instance (Protoable (IZipVector m r), Fact m, Storable r)
  => Protoable (CT m r) where
  type ProtoType (CT m r) = ProtoType (IZipVector m r)

  toProto (ZV zv)   = toProto zv
  toProto ct@(CT _) = toProto (toZV ct)

  fromProto msg = fmap (toCT . ZV) (fromProto msg)
-- | Build a 'CT'' whose every entry is the given value. The length is
-- the totient of @m@, bound outside the lambda so that it does not depend
-- on the argument.
{-# INLINE repl #-}
repl :: forall m r . (Fact m, Storable r) => r -> CT' m r
repl =
  let dim = totientFact @m
  in \x -> coerce (SV.replicate dim x)
-- | Monadic version of 'repl': run the given action once per entry
-- (totient of @m@ times) and collect the results into a 'CT''.
replM :: forall m r mon . (Fact m, Storable r, Monad mon)
      => mon r -> mon (CT' m r)
replM =
  let dim = totientFact @m
  in \gen -> fmap coerce (SV.replicateM dim gen)
-- | Build the 'CT'' of length totient-of-@m@ that holds the given value at
-- index 0 and 'zero' everywhere else.
{-# INLINE scalarPow' #-}
scalarPow' :: forall m r . (Fact m, Additive r, Storable r) => r -> CT' m r
scalarPow' =
  let dim = totientFact @m
  in \r ->
       let entry i = if i == 0 then r else zero
       in CT' (generate dim entry)
-- | Tables of values of the first component of 'crtInfo' (called @wPow@
-- below), one vector per prime-power factor @p^e@ of @m@.
--
-- For a factor @pp = p^e@, entry @i@ of the 'ru' vector is
-- @wPow (i * (m `div` pp))@; the corresponding 'ruInv' entry is
-- @wPow (negate i * (m `div` pp))@. Each vector has length @pp@.
{-# INLINE ru #-}
{-# INLINE ruInv #-}
ru, ruInv :: forall m mon r . (CRTrans mon r, Fact m, Storable r)
          => TaggedT m mon [Vector r]
ru = do
  wPow <- fmap fst crtInfo
  let mval = valueFact @m
      table (p, e) =
        let pp     = p ^ e
            stride = mval `div` pp
        in generate pp (\i -> wPow (i * stride))
  return (fmap table (ppsFact @m))

ruInv = do
  wPow <- fmap fst crtInfo
  let mval = valueFact @m
      table (p, e) =
        let pp     = p ^ e
            stride = mval `div` pp
        in generate pp (\i -> wPow (negate i * stride))
  return (fmap table (ppsFact @m))
-- | Materialise a 'Kron' value as a 'CT'' of length totient-of-@m@ by
-- reading, for each index @i@, the entry @indexK v i 0@.
wrapVector :: forall m r . (Fact m, Ring r, Storable r) => Kron r -> CT' m r
wrapVector v = CT' (generate (totientFact @m) (\row -> indexK v row 0))
-- | The vectors obtained by materialising 'gCRTK' and 'gInvCRTK' at index
-- @m@ with 'wrapVector'.
{-# INLINE gCRT #-}
{-# INLINE gInvCRT #-}
gCRT, gInvCRT :: forall m mon r . (Storable r, CRTrans mon r, Fact m) => mon (CT' m r)
gCRT    = fmap wrapVector (gCRTK @m)
gInvCRT = fmap wrapVector (gInvCRTK @m)