module Crypto.Lol.Cyclotomic.Tensor.CPP.Backend
( Dispatch
, dcrt, dcrtinv
, dgaussdec
, dl, dlinv
, dnorm
, dmulgpow, dmulgdec
, dginvpow, dginvdec
, dmul
, marshalFactors
, CPP
, withArray, withPtrArray
) where
import Crypto.Lol.Prelude as LP (Complex, PP, Proxy (..), Tagged,
map, mapM_, proxy, tag)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.Unsafe.RRq
import Crypto.Lol.Types.Unsafe.ZqBasic
import Data.Int
import Data.Vector.Storable as SV (Vector, fromList,
unsafeToForeignPtr0)
import Data.Vector.Storable.Internal (getPtr)
import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.Marshal.Array (withArray)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (Storable (..))
import GHC.TypeLits
-- | Pack a factorization, given as a list of prime powers @(p, e)@, into the
-- 'CPP' array that is handed to the C routines as a @Ptr CPP@.
--
-- Each prime and exponent is narrowed to 'Int16'. A value that does not fit
-- is reported with 'error' rather than being silently wrapped around by
-- 'fromIntegral', which would hand the C code a different factorization.
marshalFactors :: [PP] -> Vector CPP
marshalFactors = SV.fromList . LP.map toCPP
  where
    toCPP (p, e) = (narrow "prime" p, narrow "exponent" e)

    -- Checked conversion to Int16; going through Integer means the range
    -- test itself can never overflow.
    narrow :: Integral i => String -> i -> Int16
    narrow what x
      | wide < toInteger (minBound :: Int16) || wide > toInteger (maxBound :: Int16) =
          error ("marshalFactors: " ++ what ++ " " ++ show wide ++ " does not fit in Int16")
      | otherwise = fromInteger wide
      where
        wide = toInteger x
-- | Run an action on a C array holding one raw pointer per storable vector,
-- in list order.
--
-- The vectors' 'ForeignPtr's are touched once the action has returned, so
-- their buffers stay alive for as long as the action may read through the
-- raw pointers it was given.
withPtrArray :: (Storable a) => [Vector a] -> (Ptr (Ptr a) -> IO b) -> IO b
withPtrArray vecs act = do
  let fptrs = LP.map (fst . SV.unsafeToForeignPtr0) vecs
  result <- withArray (LP.map getPtr fptrs) act
  LP.mapM_ touchForeignPtr fptrs
  return result
-- | A prime power @(p, e)@ in the fixed-width form passed to the C backend.
type CPP = (Int16, Int16)

-- | Round @n@ up to the next multiple of the (positive) alignment @align@.
padTo :: Int -> Int -> Int
padTo align n = ((n + align - 1) `div` align) * align

-- | Byte offset of the second component of a pair: the first offset after
-- the first component that satisfies the second component's alignment.
-- Only 'sizeOf' / 'alignment' are used, so the arguments are never forced.
pairSndOffset :: (Storable a, Storable b) => a -> b -> Int
pairSndOffset a b = padTo (alignment b) (sizeOf a)

-- | Lay out a pair like the C struct @{ a fst; b snd; }@.
--
-- The second field is aligned to its own alignment, and the total size is
-- rounded up to the pair's alignment so that arrays of pairs (which step by
-- 'sizeOf') keep every element aligned. When both components have the same
-- size and alignment (e.g. 'CPP', or nested tuples of 'Int64') no padding
-- is inserted and the two values are stored back to back.
instance (Storable a, Storable b) => Storable (a, b) where
  sizeOf _ =
    padTo (alignment (undefined :: (a, b)))
          (pairSndOffset (undefined :: a) (undefined :: b) + sizeOf (undefined :: b))
  alignment _ = max (alignment (undefined :: a)) (alignment (undefined :: b))
  peek p = do
    a <- peekByteOff p 0
    b <- peekByteOff p (pairSndOffset a (undefined :: b))
    return (a, b)
  poke p (a, b) = do
    pokeByteOff p 0 a
    pokeByteOff p (pairSndOffset a b) b
-- Uninhabited tag types naming the C representation of an element type.
-- 'CTypeOf' maps each supported element type to one of these, and the
-- 'Dispatch'' class is indexed by them.

-- | Tag for @'ZqBasic' q 'Int64'@ (and pairs of such).
data ZqB64D
-- | Tag for @'Complex' 'Double'@ (and pairs of such).
data ComplexD
-- | Tag for 'Double'.
data DoubleD
-- | Tag for 'Int64'.
data Int64D
-- | Tag for @'RRq' q 'Double'@ (and pairs of such).
data RRqD
-- | Map a Haskell element type to the tag ('ZqB64D', 'DoubleD', 'Int64D',
-- 'ComplexD' or 'RRqD') naming its C representation.
--
-- Closed-family equations are tried top to bottom, so each supported form
-- (e.g. @'ZqBasic' q 'Int64'@) must stay above the catch-all error equation
-- for the same constructor (e.g. @'ZqBasic' q i@). Unsupported types reduce
-- to a custom 'TypeError', so misuse is reported at compile time.
type family CTypeOf x where
  -- Pairs: both halves are mapped and then reconciled by 'EqCType'.
  CTypeOf (a,b) = EqCType a b (CTypeOf a) (CTypeOf b)
  CTypeOf (ZqBasic (q :: k) Int64) = ZqB64D
  CTypeOf Double = DoubleD
  CTypeOf Int64 = Int64D
  CTypeOf (Complex Double) = ComplexD
  CTypeOf (RRq (q :: k) Double) = RRqD
  -- Same constructors with an unsupported base ring: point the user at the
  -- base ring the C backend does support.
  CTypeOf (ZqBasic (q :: k) i) = TypeError (Text "Unsupported C type: " :<>: ShowType (ZqBasic q i) :$$: Text "Use Int64 as the base ring")
  CTypeOf (Complex i) = TypeError (Text "Unsupported C type: " :<>: ShowType (Complex i) :$$: Text "Use Double as the base ring")
  CTypeOf (RRq (q :: k) i) = TypeError (Text "Unsupported C type: " :<>: ShowType (RRq q i) :$$: Text "Use Double as the base ring")
  -- Anything else has no C representation at all.
  CTypeOf a = TypeError (Text "Unsupported C type: " :<>: ShowType a)
-- | Reconcile the C representations of the two halves of a pair @(a, b)@.
-- Here @c@ and @d@ are the representations computed for @a@ and @b@
-- (see the pair equation of 'CTypeOf').
--
-- A pair is supported only when both halves share one of the tuple-capable
-- representations ('ZqB64D', 'RRqD', 'ComplexD').
--
-- * Two equal but non-tuple-capable representations (e.g. a pair of
--   'Double's) give an error naming the whole tuple type @(a, b)@.
-- * Two different representations give an error naming both components.
type family EqCType a b c d where
  EqCType a b ZqB64D ZqB64D = ZqB64D
  EqCType a b RRqD RRqD = RRqD
  EqCType a b ComplexD ComplexD = ComplexD
  EqCType a b c c = TypeError (Text "Cannot call C code on a tuple of type " :<>: ShowType (a, b))
  EqCType a b c d = TypeError (Text "You are trying to use CTensor on a tuple," :<>:
                               Text " but the tuple contains two different C types: " :$$:
                               ShowType a :<>: Text " and " :<>: ShowType b)
-- | (Possibly nested) pairs of modular types whose moduli can be handed to
-- the C backend.
class Tuple a => ZqTuple a where
  -- | One 'Int64' modulus per leaf, nested exactly like @a@.
  type ModPairs a
  -- | The moduli of all leaves of @a@.
  getModuli :: Tagged a (ModPairs a)

-- | A single 'ZqBasic' leaf contributes its reflected 'Int64' modulus.
instance Reflects q Int64 => ZqTuple (ZqBasic q Int64) where
  type ModPairs (ZqBasic q Int64) = Int64
  getModuli = tag (proxy value (Proxy :: Proxy q))

-- | A single 'RRq' leaf contributes its reflected modulus, rounded to 'Int64'.
instance (Reflects q r, RealFrac r) => ZqTuple (RRq q r) where
  type ModPairs (RRq q r) = Int64
  getModuli =
    let modulus = proxy value (Proxy :: Proxy q) :: r
    in tag (round modulus)

-- | A pair contributes the pair of its halves' moduli.
instance (ZqTuple a, ZqTuple b) => ZqTuple (a, b) where
  type ModPairs (a, b) = (ModPairs a, ModPairs b)
  getModuli = tag (proxy getModuli (Proxy :: Proxy a), proxy getModuli (Proxy :: Proxy b))
-- | Count the scalar components of a (possibly nested) pair type.
class Tuple a where
  -- | Number of leaves: 1 for a non-pair type, 3 for @((x, y), z)@.
  numComponents :: Tagged a Int16

-- | Any non-pair type is a single component. Marked @OVERLAPPABLE@ so that
-- the more specific pair instance below is chosen for pairs; without an
-- overlap pragma GHC rejects @Tuple (x, y)@ because both instances match.
instance {-# OVERLAPPABLE #-} Tuple a where
  numComponents = tag 1

-- | A pair has as many components as its two halves combined.
instance (Tuple a, Tuple b) => Tuple (a, b) where
  numComponents =
    tag $ proxy numComponents (Proxy :: Proxy a) + proxy numComponents (Proxy :: Proxy b)
-- | @Dispatch r@: element type @r@ can be handed to the C backend, with the
-- instance chosen by its representation tag @'CTypeOf' r@.
type Dispatch r = Dispatch' (CTypeOf r) r

-- | Entry points into the C tensor backend, indexed by representation tag.
--
-- The superclass equality pins @repr@ to @'CTypeOf' r@, so the instance is
-- selected by the C representation of @r@. Operations a representation does
-- not support are implemented with 'error' in its instance.
--
-- NOTE(review): in the instances, the @Int64 -> Ptr CPP -> Int16@ arguments
-- are named @totm pfac numFacts@; presumably the totient, the prime-power
-- factorization (see 'marshalFactors') and its length — confirm.
class (repr ~ CTypeOf r) => Dispatch' repr r where
  dcrt      :: Ptr (Ptr r) -> Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  dcrtinv   :: Ptr (Ptr r) -> Ptr r -> Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  dgaussdec :: Ptr (Ptr (Complex r)) -> Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  dl, dlinv, dnorm, dmulgpow, dmulgdec
            :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  dginvpow, dginvdec
            :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16
  dmul      :: Ptr r -> Ptr r -> Int64 -> IO ()
-- | Tuples of 'RRq': no C routine is wired up for this representation, so
-- every method calls 'error' with a message naming the operation.
instance (ZqTuple r, Storable (ModPairs r), CTypeOf r ~ RRqD)
  => Dispatch' RRqD r where
  dcrt = error "cannot call CT CRT on type RRq"
  dcrtinv = error "cannot call CT CRTInv on type RRq"
  dl = error "cannot call CT L on type RRq (though you probably should be able to)"
  dlinv = error "cannot call CT LInv on type RRq (though you probably should be able to)"
  dnorm = error "cannot call CT normSq on type RRq"
  dmulgpow = error "cannot call CT mulGPow on type RRq"
  dmulgdec = error "cannot call CT mulGDec on type RRq"
  dginvpow = error "cannot call CT divGPow on type RRq"
  dginvdec = error "cannot call CT divGDec on type RRq"
  dmul = error "cannot call CT mul on type RRq"
  dgaussdec = error "cannot call CT gaussianDec on type RRq"
-- | Tuples of @'ZqBasic' q 'Int64'@.
--
-- Each supported method obtains the component count and a pointer to the
-- marshalled moduli of @r@ from 'withZqModuli' and forwards to the matching
-- @...Rq@ C routine. 'dnorm' and 'dgaussdec' are unsupported and call
-- 'error'.
instance (ZqTuple r, Storable (ModPairs r), CTypeOf r ~ ZqB64D)
  => Dispatch' ZqB64D r where
  dcrt ruptr pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorCRTRq numPairs (castPtr pout) totm pfac numFacts (castPtr ruptr) qsptr
  dcrtinv ruptr minv pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorCRTInvRq numPairs (castPtr pout) totm pfac numFacts
                     (castPtr ruptr) (castPtr minv) qsptr
  dl pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorLRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dlinv pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorLInvRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dnorm = error "cannot call CT normSq on type ZqBasic"
  dmulgpow pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorGPowRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dmulgdec pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorGDecRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dginvpow pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorGInvPowRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dginvdec pout totm pfac numFacts =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      tensorGInvDecRq numPairs (castPtr pout) totm pfac numFacts qsptr
  dmul aout bout totm =
    withZqModuli (Proxy :: Proxy r) $ \numPairs qsptr ->
      mulRq numPairs (castPtr aout) (castPtr bout) totm qsptr
  dgaussdec = error "cannot call CT gaussianDec on type ZqBasic"

-- | Marshal the moduli of the modular tuple type @r@ into a temporary buffer
-- (allocated with 'with', so valid only inside the continuation).  The
-- continuation receives the number of components of @r@ and a pointer to
-- that buffer viewed as an array of 'Int64'.
withZqModuli :: (ZqTuple r, Storable (ModPairs r))
             => Proxy r -> (Int16 -> Ptr Int64 -> IO b) -> IO b
withZqModuli rep k =
  with (proxy getModuli rep) $ \qsptr ->
    k (proxy numComponents rep) (castPtr qsptr)
-- | Tuples of @'Complex' 'Double'@.
--
-- Each supported method forwards to the matching @...C@ C routine, passing
-- the number of components of @r@ as the leading argument. 'dnorm' and
-- 'dgaussdec' are unsupported and call 'error'.
instance (Tuple r, CTypeOf r ~ ComplexD) => Dispatch' ComplexD r where
  dcrt ruptr pout totm pfac numFacts =
    tensorCRTC (proxy numComponents (Proxy::Proxy r)) (castPtr pout) totm pfac numFacts (castPtr ruptr)
  dcrtinv ruptr minv pout totm pfac numFacts =
    tensorCRTInvC (proxy numComponents (Proxy::Proxy r)) (castPtr pout) totm pfac numFacts (castPtr ruptr) (castPtr minv)
  dl pout =
    tensorLC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dlinv pout =
    tensorLInvC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dnorm = error "cannot call CT normSq on type Complex Double"
  dmulgpow pout =
    tensorGPowC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dmulgdec pout =
    tensorGDecC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dginvpow pout =
    tensorGInvPowC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dginvdec pout =
    tensorGInvDecC (proxy numComponents (Proxy::Proxy r)) (castPtr pout)
  dmul aout bout =
    mulC (proxy numComponents (Proxy::Proxy r)) (castPtr aout) (castPtr bout)
  dgaussdec = error "cannot call CT gaussianDec on type Complex Double"
-- | Plain 'Double' tensors. The component-count argument is fixed to 1.
-- L, L^-1, normSq and Gaussian decoding are backed by C routines; the other
-- operations call 'error'.
instance Dispatch' DoubleD Double where
  dl    = tensorLDouble 1 . castPtr
  dlinv = tensorLInvDouble 1 . castPtr
  dnorm = tensorNormSqD 1 . castPtr
  dgaussdec ruptr pout totm pfac numFacts =
    tensorGaussianDec 1 (castPtr pout) totm pfac numFacts (castPtr ruptr)

  dcrt     = error "cannot call CT Crt on type Double"
  dcrtinv  = error "cannot call CT CrtInv on type Double"
  dmulgpow = error "cannot call CT mulGPow on type Double"
  dmulgdec = error "cannot call CT mulGDec on type Double"
  dginvpow = error "cannot call CT divGPow on type Double"
  dginvdec = error "cannot call CT divGDec on type Double"
  dmul     = error "cannot call CT (*) on type Double"
-- | Plain 'Int64' tensors. The component-count argument is fixed to 1.
-- L, L^-1, normSq, mulG and divG are backed by the @...R@ C routines; CRT,
-- multiplication and Gaussian decoding call 'error'.
instance Dispatch' Int64D Int64 where
  dl       = tensorLR 1 . castPtr
  dlinv    = tensorLInvR 1 . castPtr
  dnorm    = tensorNormSqR 1 . castPtr
  dmulgpow = tensorGPowR 1 . castPtr
  dmulgdec = tensorGDecR 1 . castPtr
  dginvpow = tensorGInvPowR 1 . castPtr
  dginvdec = tensorGInvDecR 1 . castPtr

  dcrt      = error "cannot call CT Crt on type Int64"
  dcrtinv   = error "cannot call CT CrtInv on type Int64"
  dmul      = error "cannot call CT (*) on type Int64"
  dgaussdec = error "cannot call CT gaussianDec on type Int64"
-- Raw bindings to the C tensor backend.
--
-- Element-type suffixes used by the Haskell-side names:
--   * @R@            -- 'Int64' buffers
--   * @Rq@           -- 'ZqBasic' buffers, plus a trailing @Ptr Int64@ that
--                       the ZqB64D instance fills with the marshalled moduli
--   * @Double@ / @D@ -- 'Double' buffers
--   * @C@            -- 'Complex' 'Double' buffers
--
-- In every call made from this module the leading 'Int16' is the number of
-- tuple components ('numComponents', or the literal 1 in the scalar
-- instances). Where present, the @Int64 -> Ptr CPP -> Int16@ run receives
-- the arguments the instances call @totm pfac numFacts@.
-- NOTE(review): presumably totient, prime-power factorization and its
-- length -- confirm against the C sources.
--
-- The @GInv@ routines return an 'Int16' that 'dginvpow' / 'dginvdec' pass
-- back unchanged; NOTE(review): its meaning is defined on the C side
-- (possibly a success flag) -- confirm.
--
-- All imports are @unsafe@, so the C routines must not call back into
-- Haskell.
foreign import ccall unsafe "tensorLR" tensorLR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorLInvR" tensorLInvR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorLRq" tensorLRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorLInvRq" tensorLInvRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorLDouble" tensorLDouble :: Int16 -> Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorLInvDouble" tensorLInvDouble :: Int16 -> Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorLC" tensorLC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorLInvC" tensorLInvC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorNormSqR" tensorNormSqR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorNormSqD" tensorNormSqD :: Int16 -> Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGPowR" tensorGPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGPowRq" tensorGPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorGPowC" tensorGPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecR" tensorGDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecRq" tensorGDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorGDecC" tensorGDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
-- CRT routines: the @Ptr (Ptr ...)@ argument receives the instances' @ruptr@
-- and, for the inverse, the following buffer receives @minv@.
foreign import ccall unsafe "tensorCRTRq" tensorCRTRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorCRTC" tensorCRTC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
foreign import ccall unsafe "tensorCRTInvRq" tensorCRTInvRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Ptr (ZqBasic q Int64) -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorCRTInvC" tensorCRTInvC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> Ptr (Complex Double) -> IO ()
foreign import ccall unsafe "tensorGaussianDec" tensorGaussianDec :: Int16 -> Ptr Double -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
-- Element-wise multiplication of two buffers; only the Int64 length-like
-- argument (@totm@ at the call sites) is passed, no factorization.
foreign import ccall unsafe "mulRq" mulRq :: Int16 -> Ptr (ZqBasic q Int64) -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "mulC" mulC :: Int16 -> Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> IO ()