module Crypto.Lol.Cyclotomic.Tensor.CTensor (CT) where
import Algebra.Additive as Additive (C)
import Algebra.Module as Module (C)
import Algebra.ZeroTestable as ZeroTestable (C)
import Control.Applicative hiding ((*>))
import Control.Arrow ((***))
import Control.DeepSeq
import Control.Monad.Except
import Control.Monad.Identity (Identity (..), runIdentity)
import Control.Monad.Random
import Control.Monad.Trans as T (lift)
import Data.Coerce
import Data.Constraint hiding ((***))
import Data.Int
import Data.Maybe
import Data.Traversable as T
import Data.Vector.Generic as V (fromList, toList, unzip)
import Data.Vector.Storable as SV (Vector, convert, foldl',
fromList, generate,
length, map, replicate,
replicateM, thaw, thaw, toList,
unsafeFreeze,
unsafeWith, zipWith, (!))
import Data.Vector.Storable.Mutable as SM hiding (replicate)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr
import Test.QuickCheck hiding (generate)
import Crypto.Lol.CRTrans
import Crypto.Lol.Cyclotomic.Tensor
import Crypto.Lol.Cyclotomic.Tensor.CTensor.Backend
import Crypto.Lol.Cyclotomic.Tensor.CTensor.Extension
import Crypto.Lol.GaussRandom
import Crypto.Lol.Prelude as LP hiding
(replicate,
unzip, zip)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.FiniteField
import Crypto.Lol.Types.IZipVector
import Crypto.Lol.Types.Proto
import Crypto.Lol.Types.RRq
import Crypto.Lol.Types.ZqBasic
import Crypto.Proto.RLWE.Kq
import Crypto.Proto.RLWE.Rq
import Data.Foldable as F
import Data.Sequence as S (fromList)
import System.IO.Unsafe (unsafePerformIO)
-- | Storable tensor representation: a flat 'Vector' of coefficients, indexed
-- at the type level by the cyclotomic index @m@.
newtype CT' (m :: Factored) r = CT' { unCT :: Vector r }
  deriving (Show, Eq, NFData)
-- Roles are given in parameter order: @m@ is representational and @r@ is
-- nominal, so 'coerce' cannot change the coefficient type of a 'CT''.
-- NOTE(review): presumably @r@ is nominal because a coefficient newtype may
-- have a different 'Storable' layout — confirm.
type role CT' representational nominal
-- | Public tensor type with two representations.
--
-- * 'CT' wraps the storable 'CT'' form and carries the 'Storable' dictionary
--   for @r@ (recovered by pattern matching).
-- * 'ZV' wraps an 'IZipVector', which needs no 'Storable' instance; the
--   'Applicative' instance produces this form (e.g. @pure = ZV . pure@).
data CT (m :: Factored) r where
  CT :: Storable r => CT' m r -> CT m r
  ZV :: IZipVector m r -> CT m r

deriving instance Show r => Show (CT m r)
-- | Equality on tensors. Pairs built with the same constructor compare their
-- payloads directly. For a mixed pair, the 'ZV' side is converted with 'toCT'
-- and the two 'CT' forms are compared. The 'Storable' dictionary that 'toCT'
-- needs comes from matching the 'CT' side.
instance Eq r => Eq (CT m r) where
  ZV u == ZV w = u == w
  CT u == CT w = u == w
  lhs@(CT _) == rhs = lhs == toCT rhs
  lhs == rhs@(CT _) = rhs == toCT lhs
-- | Protobuf serialization of tensors over @ZqBasic q Int64@ using the 'Rq'
-- message. The message carries the index @m@, the modulus @q@ and the
-- coefficient list.
instance (Fact m, Reflects q Int64) => Protoable (CT m (ZqBasic q Int64)) where
  type ProtoType (CT m (ZqBasic q Int64)) = Rq

  -- Coefficients are lifted to their Int64 representatives before encoding.
  toProto (CT (CT' xs)) =
    let m = fromIntegral $ proxy valueFact (Proxy::Proxy m)
        q = proxy value (Proxy::Proxy q) :: Int64
    in Rq m (fromIntegral q) $ S.fromList $ SV.toList $ SV.map LP.lift xs
  -- A 'ZV' tensor is first converted to the storable form.
  toProto x@(ZV _) = toProto $ toCT x

  -- Decoding requires that m, the number of coefficients (which must equal
  -- phi(m)) and q all match the target type. Otherwise an error describing
  -- each expected/actual value is thrown; on success, coefficients are
  -- reduced into ZqBasic.
  fromProto (Rq m' q' xs) =
    let m = proxy valueFact (Proxy::Proxy m) :: Int
        q = proxy value (Proxy::Proxy q) :: Int64
        n = proxy totientFact (Proxy::Proxy m)
        xs' = SV.fromList $ F.toList xs
        len = F.length xs
    in if m == fromIntegral m' && len == n && fromIntegral q == q'
       then return $ CT $ CT' $ SV.map reduce xs'
       else throwError $
         "An error occurred while reading the proto type for CT.\n\
         \Expected m=" ++ show m ++ ", got " ++ show m' ++ "\n\
         \Expected n=" ++ show n ++ ", got " ++ show len ++ "\n\
         \Expected q=" ++ show q ++ ", got " ++ show q' ++ "."
-- | Protobuf serialization of tensors over @RRq q Double@ using the 'Kq'
-- message. The message carries the index @m@, the modulus @q@ (rounded to an
-- integer) and the coefficient list.
instance (Fact m, Reflects q Double) => Protoable (CT m (RRq q Double)) where
  type ProtoType (CT m (RRq q Double)) = Kq

  -- Coefficients are lifted with 'LP.lift' before encoding.
  toProto (CT (CT' xs)) =
    let m = fromIntegral $ proxy valueFact (Proxy::Proxy m)
        q = round (proxy value (Proxy::Proxy q) :: Double)
    in Kq m q $ S.fromList $ SV.toList $ SV.map LP.lift xs
  -- A 'ZV' tensor is first converted to the storable form.
  toProto x@(ZV _) = toProto $ toCT x

  -- Decoding requires that m, the number of coefficients (which must equal
  -- phi(m)) and the rounded q all match the target type. Otherwise an error
  -- describing each expected/actual value is thrown; on success, coefficients
  -- are reduced into RRq.
  fromProto (Kq m' q' xs) =
    let m = proxy valueFact (Proxy::Proxy m) :: Int
        q = round (proxy value (Proxy::Proxy q) :: Double)
        n = proxy totientFact (Proxy::Proxy m)
        xs' = SV.fromList $ F.toList xs
        len = F.length xs
    in if m == fromIntegral m' && len == n && q == q'
       then return $ CT $ CT' $ SV.map reduce xs'
       else throwError $
         "An error occurred while reading the proto type for CT.\n\
         \Expected m=" ++ show m ++ ", got " ++ show m' ++ "\n\
         \Expected n=" ++ show n ++ ", got " ++ show len ++ "\n\
         \Expected q=" ++ show q ++ ", got " ++ show q' ++ "."
-- | Force a tensor into the storable 'CT' form. Tensors already in that form
-- are returned unchanged.
toCT :: (Storable r) => CT m r -> CT m r
toCT t = case t of
  CT _  -> t
  ZV zv -> CT (zvToCT' zv)
-- | Force a tensor into the 'ZV' ('IZipVector') form. The 'Storable'
-- dictionary needed to read the storable vector comes from matching 'CT'.
-- Errors if the vector is rejected by 'iZipVector', which is treated as an
-- internal invariant violation.
toZV :: (Fact m) => CT m r -> CT m r
toZV t = case t of
  CT (CT' vec) ->
    ZV $ fromMaybe (error "toZV: internal error") (iZipVector (convert vec))
  ZV _ -> t
-- | Copy the contents of an 'IZipVector' into a storable 'CT''.
zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r
zvToCT' zv = CT' storable
  where
    storable = convert (unIZipVector zv) :: Vector r
-- | Lift a function on storable tensors to 'CT'. A 'ZV' argument is converted
-- to storable form first, and the result is always in 'CT' form.
wrap :: (Storable s, Storable r) => (CT' l s -> CT' m r) -> (CT l s -> CT m r)
wrap f t = CT $ f $ case t of
  CT v  -> v
  ZV zv -> zvToCT' zv
-- | Monadic counterpart of 'wrap'. A 'ZV' argument is converted to storable
-- form, the effectful function is run, and its result is tagged with 'CT'.
wrapM :: (Storable s, Storable r, Monad mon) => (CT' l s -> mon (CT' m r))
      -> (CT l s -> mon (CT m r))
wrapM f t = fmap CT $ f $ case t of
  CT v  -> v
  ZV zv -> zvToCT' zv
-- | Maps a (possibly 'Maybe'-wrapped) function from 'CT'' index @m'@ to index
-- @m@ onto the corresponding function on raw vectors, tagged by the index pair
-- @'(m,m')@. The shape matches 'coerceTw'.
-- NOTE(review): not referenced in this part of the file — confirm it is still used.
type family Tw (r :: *) :: * where
  Tw (CT' m' r -> CT' m r) = Tagged '(m,m') (Vector r -> Vector r)
  Tw (Maybe (CT' m' r -> CT' m r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)
-- | Maps a (possibly 'Maybe'-wrapped) function from 'CT'' index @m@ to index
-- @m'@ onto the corresponding function on raw vectors, tagged by the index pair
-- @'(m,m')@. The shape matches 'coerceEm'.
-- NOTE(review): not referenced in this part of the file — confirm it is still used.
type family Em r where
  Em (CT' m r -> CT' m' r) = Tagged '(m,m') (Vector r -> Vector r)
  Em (Maybe (CT' m r -> CT' m' r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)
-- | Coefficient-wise addition and negation. Operands in 'ZV' form are first
-- converted to storable 'CT' form. 'zero' is the all-zero storable vector of
-- length phi(m).
instance (Additive r, Storable r, Fact m)
  => Additive.C (CT m r) where
  CT (CT' xs) + CT (CT' ys) = CT (CT' (SV.zipWith (+) xs ys))
  lhs + rhs = toCT lhs + toCT rhs

  negate (CT (CT' xs)) = CT (CT' (SV.map negate xs))
  negate t = negate (toCT t)

  zero = CT (repl zero)
-- | A tensor is zero when every coefficient is zero. The storable form is
-- checked with a strict left fold over the coefficients; the 'ZV' form defers
-- to the 'IZipVector' instance.
instance (ZeroTestable r, Storable r)
  => ZeroTestable.C (CT m r) where
  isZero t = case t of
    CT (CT' cs) -> SV.foldl' (\allZero c -> allZero && isZero c) True cs
    ZV zv       -> isZero zv
-- | Scalar multiplication by a finite-field element. The coefficients are
-- wrapped as 'Coeffs', scaled with the 'GF' module action, and unwrapped back
-- into the same representation the tensor started in.
instance (GFCtx fp d, Fact m, Additive (CT m fp))
  => Module.C (GF fp d) (CT m fp) where
  r *> (CT (CT' arr)) =
    CT $ CT' $ SV.fromList $ unCoeffs $ r *> Coeffs $ SV.toList arr
  r *> (ZV zv) =
    ZV $ fromJust $ iZipVector $ V.fromList $ unCoeffs $ r *> Coeffs $ V.toList $ unIZipVector zv
-- | Mapping goes through the 'Applicative' instance, @fmap f = (pure f <*>)@,
-- so the result is always in 'ZV' form.
instance Fact m => Functor (CT m) where
  fmap f = (pure f <*>)
-- | Applicative structure on tensors. Results are always in 'ZV' form. Any
-- argument in storable 'CT' form is converted with 'toZV', which needs only
-- @Fact m@, before the 'IZipVector' operations are applied.
instance Fact m => Applicative (CT m) where
  pure = ZV . pure

  (ZV f) <*> (ZV a) = ZV (f <*> a)
  f@(ZV _) <*> v@(CT _) = f <*> toZV v
  -- A 'CT'-wrapped function vector would need a 'Storable' instance for a
  -- function type, so this case is not expected in practice. Converting it
  -- with 'toZV' keeps '<*>' total instead of failing with a non-exhaustive
  -- pattern error.
  f@(CT _) <*> v = toZV f <*> v
-- | Folding is derived from the 'Traversable' instance via 'foldMapDefault'.
instance Fact m => Foldable (CT m) where
  foldMap f = foldMapDefault f
-- | Traversal operates on the 'ZV' form. A storable tensor is converted with
-- 'toZV' first, and results are rebuilt as 'ZV'.
instance Fact m => Traversable (CT m) where
  traverse f t = case t of
    CT _  -> T.traverse f (toZV t)
    ZV zv -> ZV <$> T.traverse f zv
-- | 'Tensor' instance for 'CT'. Most operations convert 'ZV' arguments to the
-- storable form (via 'wrap' / 'wrapM' / 'toCT') and call into the C backend
-- through the dispatch helpers defined in this module.
instance Tensor CT where
  -- Tensor elements must be storable and have a backend ('Dispatch') instance.
  type TElt CT r = (Storable r, Dispatch r)

  -- Each entailment is discharged directly from the instance constraints.
  entailIndexT = tag $ Sub Dict
  entailEqT = tag $ Sub Dict
  entailZTT = tag $ Sub Dict
  entailNFDataT = tag $ Sub Dict
  entailRandomT = tag $ Sub Dict
  entailShowT = tag $ Sub Dict
  entailModuleT = tag $ Sub Dict

  scalarPow = CT . scalarPow'

  -- In-place C routines, each applied to a copy of the coefficient vector.
  l = wrap $ basicDispatch dl
  lInv = wrap $ basicDispatch dlinv
  mulGPow = wrap $ basicDispatch dmulgpow
  mulGDec = wrap $ basicDispatch dmulgdec

  -- These may fail: 'dispatchGInv' yields 'Nothing' when the C routine
  -- returns 0.
  divGPow = wrapM $ dispatchGInv dginvpow
  divGDec = wrapM $ dispatchGInv dginvdec

  -- Five CRT functions built in the CRTrans monad: scalar replication,
  -- elementwise multiplication by the precomputed 'gCRT' and 'gInvCRT'
  -- vectors, and the forward/inverse CRT transforms.
  crtFuncs = (,,,,) <$>
    return (CT . repl) <*>
    (wrap . untag (cZipDispatch dmul) <$> gCRT) <*>
    (wrap . untag (cZipDispatch dmul) <$> gInvCRT) <*>
    (wrap <$> untagT ctCRT) <*>
    (wrap <$> untagT ctCRTInv)

  -- The Identity-based twace/embed helpers are run here and lifted to 'CT'.
  twacePowDec = wrap $ runIdentity $ coerceTw twacePowDec'
  embedPow = wrap $ runIdentity $ coerceEm embedPow'
  embedDec = wrap $ runIdentity $ coerceEm embedDec'

  tGaussianDec v = CT <$> cDispatchGaussian v

  -- A 'ZV' input is converted to storable form before calling the C norm.
  gSqNormDec (CT v) = untag gSqNormDec' v
  gSqNormDec (ZV v) = gSqNormDec (CT $ zvToCT' v)

  crtExtFuncs = (,) <$> (wrap <$> coerceTw twaceCRT')
                    <*> (wrap <$> coerceEm embedCRT')

  coeffs = wrapM $ coerceCoeffs coeffs'

  powBasisPow = (CT <$>) <$> coerceBasis powBasisPow'

  crtSetDec = (CT <$>) <$> coerceBasis crtSetDec'

  fmapT f = wrap $ coerce (SV.map f)

  -- Both inputs are forced into storable form and combined elementwise.
  zipWithT f v1' v2' =
    let (CT (CT' v1)) = toCT v1'
        (CT (CT' v2)) = toCT v2'
    in CT $ CT' $ SV.zipWith f v1 v2

  -- The input is forced into storable form, then its vector of pairs is split.
  unzipT v =
    let (CT (CT' x)) = toCT v
    in (CT . CT') *** (CT . CT') $ unzip x
-- | Strip the @'(m,m')@ tag from a raw-vector twace function and retype it as
-- a function from 'CT'' index @m'@ to index @m@ inside @mon@.
coerceTw :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m' r -> CT' m r)
coerceTw t = fmap coerce (untagT t)
-- | Strip the @'(m,m')@ tag from a raw-vector embedding function and retype it
-- as a function from 'CT'' index @m@ to index @m'@ inside @mon@.
coerceEm :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m r -> CT' m' r)
coerceEm t = fmap coerce (untagT t)
-- | Zero-cost retyping of a tagged raw-vector coefficient-splitting function
-- into one from a 'CT'' at index @m'@ to a list of 'CT''s at index @m@.
coerceCoeffs :: Tagged '(m,m') (Vector r -> [Vector r]) -> CT' m' r -> [CT' m r]
coerceCoeffs t = coerce t
-- | Zero-cost retyping of a @'(m,m')@-tagged list of raw vectors into an
-- @m@-tagged list of 'CT''s at index @m'@.
coerceBasis :: Tagged '(m,m') [Vector r] -> Tagged m [CT' m' r]
coerceBasis t = coerce t
-- | Run a C "divide by g"-style routine (per its callers 'divGPow'/'divGDec')
-- on a copy of the input. The routine receives the output buffer, phi(m) as
-- 'Int64', a pointer to the marshalled prime-power factors of @m@, and the
-- length of that factor vector as 'Int16'. A zero return code is treated as
-- failure and yields 'Nothing'.
--
-- 'unsafePerformIO' is used on the assumption that the C routine only writes
-- through the output pointer (a freshly thawed copy) — NOTE(review): confirm
-- on the C side.
dispatchGInv :: forall m r . (Storable r, Fact m)
             => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16)
             -> CT' m r -> Maybe (CT' m r)
dispatchGInv f =
  -- Index-dependent arguments are computed once, outside the lambda.
  let factors = proxy (marshalFactors <$> ppsFact) (Proxy::Proxy m)
      totm = proxy (fromIntegral <$> totientFact) (Proxy::Proxy m)
      numFacts = fromIntegral $ SV.length factors
  in \(CT' x) -> unsafePerformIO $ do
    -- Private mutable copy; the input vector itself is never handed to C.
    yout <- SV.thaw x
    ret <- SM.unsafeWith yout (\pout ->
             SV.unsafeWith factors (\pfac ->
               f pout totm pfac numFacts))
    -- Non-zero return code means success.
    if ret /= 0
      then Just . CT' <$> unsafeFreeze yout
      else return Nothing
-- | Marshal the standard arguments for an in-place C routine and run it on a
-- fresh mutable copy of the input. The routine receives the output buffer,
-- phi(m) as 'Int64', a pointer to the marshalled prime-power factors of @m@,
-- and the length of that factor vector as 'Int16'. The mutated copy is frozen
-- and returned; the input vector is not passed to C.
withBasicArgs :: forall m r . (Fact m, Storable r)
              => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
              -> CT' m r -> IO (CT' m r)
withBasicArgs f =
  -- Index-dependent arguments are computed once per partial application.
  let factors = proxy (marshalFactors <$> ppsFact) (Proxy::Proxy m)
      totm = proxy (fromIntegral <$> totientFact) (Proxy::Proxy m)
      numFacts = fromIntegral $ SV.length factors
  in \(CT' x) -> do
    yout <- SV.thaw x
    SM.unsafeWith yout (\pout ->
      SV.unsafeWith factors (\pfac ->
        f pout totm pfac numFacts))
    -- Safe to freeze without copying: yout is not used after this point.
    CT' <$> unsafeFreeze yout
-- | Pure wrapper around 'withBasicArgs' that runs the in-place C routine on a
-- copy of the input via 'unsafePerformIO'. This relies on the routine's output
-- depending only on its arguments — an assumption about the C side
-- (NOTE(review): confirm). The point-free form keeps @withBasicArgs f@ (and
-- its precomputed factor data) shared across calls.
basicDispatch :: (Storable r, Fact m)
              => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
              -> CT' m r -> CT' m r
basicDispatch f = unsafePerformIO . withBasicArgs f
-- | Norm computation backing 'gSqNormDec'. The C routine 'dnorm' is run on a
-- copy of the coefficients and the result is read from index 0 of the output
-- vector.
-- NOTE(review): assumes 'dnorm' writes the norm into position 0 — confirm.
gSqNormDec' :: (Storable r, Fact m, Dispatch r)
            => Tagged m (CT' m r -> r)
gSqNormDec' = return $ (!0) . unCT . unsafePerformIO . withBasicArgs dnorm
-- | Forward CRT transform performed by the C routine 'dcrt'. The per-factor
-- roots of unity from 'ru' are obtained in @mon@, then passed to C as an array
-- of vectors.
ctCRT :: (Storable r, CRTrans mon r, Dispatch r, Fact m)
      => TaggedT m mon (CT' m r -> CT' m r)
ctCRT = do
  ru' <- ru
  return $ \x -> unsafePerformIO $
    withPtrArray ru' (flip withBasicArgs x . dcrt)
-- | Inverse CRT transform performed by the C routine 'dcrtinv'. It receives
-- the per-factor inverse roots from 'ruInv' and a pointer to the second
-- component of 'crtInfo' (bound as @mhatInv@; presumably the inverse of m-hat
-- used for scaling — confirm).
ctCRTInv :: (Storable r, CRTrans mon r, Dispatch r, Fact m)
         => TaggedT m mon (CT' m r -> CT' m r)
ctCRTInv = do
  mhatInv <- snd <$> crtInfo
  ruinv' <- ruInv
  return $ \x -> unsafePerformIO $
    withPtrArray ruinv' (\ruptr -> with mhatInv (flip withBasicArgs x . dcrtinv ruptr))
-- | Lift an elementwise binary C routine over phi(m) coefficients to a pure
-- binary function on 'CT''. The first argument is copied and used as the
-- output buffer. The second is passed as the input pointer (presumably
-- read-only on the C side — confirm).
cZipDispatch :: (Storable r, Fact m)
             => (Ptr r -> Ptr r -> Int64 -> IO ())
             -> Tagged m (CT' m r -> CT' m r -> CT' m r)
cZipDispatch f = do
  totm <- fromIntegral <$> totientFact
  return $ coerce $ \a b -> unsafePerformIO $ do
    yout <- SV.thaw a
    SM.unsafeWith yout (\pout ->
      SV.unsafeWith b (\pin ->
        f pout pin totm))
    unsafeFreeze yout
-- | Sample a tensor for 'tGaussianDec'. It draws phi(m) real Gaussians via
-- 'realGaussians' with the variance scaled by @m `div` rad(m)@, then
-- post-processes them in C with 'dgaussdec'.
-- The roots are obtained from 'ru' run in 'Maybe'; if they are unavailable,
-- evaluation fails with the error "complexGaussianRoots".
--
-- NOTE(review): the roots are bound as @ruinv'@ but come from 'ru', not
-- 'ruInv' — confirm which set 'dgaussdec' expects.
cDispatchGaussian :: forall m r var rnd .
  (Storable r, Transcendental r, Dispatch r, Ord r,
   Fact m, ToRational var, Random r, MonadRandom rnd)
  => var -> rnd (CT' m r)
cDispatchGaussian var = flip proxyT (Proxy::Proxy m) $ do
  ruinv' <- mapTaggedT (return . fromMaybe (error "complexGaussianRoots")) ru
  totm <- pureT totientFact
  m <- pureT valueFact
  rad <- pureT radicalFact
  yin <- T.lift $ realGaussians (var * fromIntegral (m `div` rad)) totm
  return $ unsafePerformIO $
    withPtrArray ruinv' (\ruptr -> withBasicArgs (dgaussdec ruptr) (CT' yin))
-- | Random storable tensors with phi(m) independently generated coefficients.
-- Shrinking is not supported.
instance (Arbitrary r, Fact m, Storable r) => Arbitrary (CT' m r) where
  arbitrary = replM arbitrary
  shrink _ = []
-- | Random tensors, always generated in storable 'CT' form.
instance (Storable r, Arbitrary (CT' m r)) => Arbitrary (CT m r) where
  arbitrary = fmap CT arbitrary
-- | Uniformly random coefficients (phi(m) draws of 'random'). Ranged
-- generation is not meaningful for tensors and errors.
instance (Storable r, Random r, Fact m) => Random (CT' m r) where
  random g = runRand (replM (liftRand random)) g
  randomR = error "randomR nonsensical for CT'"
-- | Random tensors in storable 'CT' form. Ranged generation errors.
instance (Storable r, Random (CT' m r)) => Random (CT m r) where
  random g = runRand (fmap CT (liftRand random)) g
  randomR = error "randomR nonsensical for CT"
-- | Full evaluation of whichever representation the tensor currently holds.
instance (NFData r) => NFData (CT m r) where
  rnf t = case t of
    CT ct -> rnf ct
    ZV zv -> rnf zv
-- | Storable tensor whose phi(m) coefficients all equal the given value.
repl :: forall m r . (Fact m, Storable r) => r -> CT' m r
repl = CT' . SV.replicate phi
  where
    phi = proxy totientFact (Proxy::Proxy m)
-- | Storable tensor whose phi(m) coefficients are produced by running the
-- given action phi(m) times.
replM :: forall m r mon . (Fact m, Storable r, Monad mon)
      => mon r -> mon (CT' m r)
replM = fmap CT' . SV.replicateM phi
  where
    phi = proxy totientFact (Proxy::Proxy m)
-- | Embed a scalar: a length-phi(m) vector whose first (index-0) coefficient
-- is the scalar and whose remaining coefficients are 'zero'.
scalarPow' :: forall m r . (Fact m, Additive r, Storable r) => r -> CT' m r
scalarPow' = \c -> CT' (generate phi (coeffAt c))
  where
    phi = proxy totientFact (Proxy::Proxy m)
    coeffAt c i
      | i == 0    = c
      | otherwise = zero
-- | Roots of unity handed to the C CRT routines, one vector per prime-power
-- factor @pp = p^e@ of @m@ (from 'ppsFact'). @wPow@ is the power function from
-- 'crtInfo', and @pow = m `div` pp@.
--
-- * 'ru': entry @i@ of the @pp@ vector is @wPow (i * pow)@.
-- * 'ruInv': entry @i@ is @wPow (negate (i * pow))@, the inverse of the
--   corresponding 'ru' entry. It is used by 'ctCRTInv'.
--
-- NOTE(review): 'ruInv' relies on @wPow@ accepting negative exponents, which
-- the 'crtInfo' power function is expected to do — confirm.
ru, ruInv :: (CRTrans mon r, Fact m, Storable r)
   => TaggedT m mon [Vector r]
ru = do
  mval <- pureT valueFact
  wPow <- fst <$> crtInfo
  LP.map
    (\(p,e) -> do
        let pp = p^e
            pow = mval `div` pp
        generate pp (wPow . (*pow))) <$>
      pureT ppsFact

ruInv = do
  mval <- pureT valueFact
  wPow <- fst <$> crtInfo
  LP.map
    (\(p,e) -> do
        let pp = p^e
            pow = mval `div` pp
        -- Negated exponent: the inverse of omega^k is omega^(-k).
        generate pp (\i -> wPow $ negate (i*pow))) <$>
      pureT ppsFact
-- | Run a tagged Kronecker-matrix computation at index @m@ and return column 0
-- of the resulting matrix as a length-phi(m) storable tensor.
wrapVector :: forall mon m r . (Monad mon, Fact m, Ring r, Storable r)
           => TaggedT m mon (Kron r) -> mon (CT' m r)
wrapVector v = firstColumn <$> proxyT v (Proxy::Proxy m)
  where
    phi = proxy totientFact (Proxy::Proxy m)
    firstColumn vmat = CT' $ generate phi (\row -> indexK vmat row 0)
-- | Precomputed CRT-basis vectors taken from the Kronecker sources 'gCRTK' and
-- 'gInvCRTK' (first column, via 'wrapVector'). 'crtFuncs' multiplies by them
-- elementwise. Judging by the names, they presumably represent g and g^-1 in
-- the CRT basis — confirm.
gCRT, gInvCRT :: (Storable r, CRTrans mon r, Fact m)
              => mon (CT' m r)
gCRT = wrapVector gCRTK
gInvCRT = wrapVector gInvCRTK