{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Crypto.Lol.Cyclotomic.Tensor.CPP.Extension
( embedPow', embedDec', embedCRT'
, twacePowDec', twaceCRT'
, coeffs', powBasisPow'
, crtSetDec'
, backpermute'
) where
import Crypto.Lol.CRTrans
import Crypto.Lol.Cyclotomic.Tensor as T
import Crypto.Lol.Cyclotomic.Tensor.CPP.Instances ()
import Crypto.Lol.Prelude as LP hiding (lift, null)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.FiniteField
import Crypto.Lol.Types.Unsafe.ZqBasic hiding (ZqB, unZqB)
import Crypto.Lol.Types.ZmStar
import Control.Applicative hiding (empty)
import Control.Monad.Trans (lift)
import Data.Maybe
import Data.Reflection (reify)
import qualified Data.Vector as V
import Data.Vector.Storable as SV
import qualified Data.Vector.Unboxed as U
-- | Gather: the result has one entry per element of the index vector @is@,
-- and entry @k@ is the element of @v@ at position @is U.! k@.
--
-- Both lookups use the checked indexing operators ('!' and 'U.!'), so an
-- out-of-range index is an error rather than undefined behaviour.
backpermute' :: (Storable a) =>
                U.Vector Int -- ^ source position for each output slot
             -> Vector a     -- ^ vector to read from
             -> Vector a
{-# INLINABLE backpermute' #-}
backpermute' is v = generate outLen pick
  where
    outLen = U.length is
    pick k = v ! (is U.! k)
-- | Embeddings of a vector indexed by @m@ into one indexed by @m'@.
--
-- 'embedPow'' builds one output entry per element of 'baseIndicesPow'.
-- When that element is @(0, j1)@, the entry is the input's @j1@-th value;
-- otherwise it is 'zero'.
--
-- 'embedDec'' builds one output entry per element of 'baseIndicesDec'.
-- 'Nothing' gives 'zero'. @Just (sh, b)@ gives the input's @sh@-th value,
-- negated when @b@ is 'True'.
--
-- In both cases the index table is bound once and shared by every
-- application of the returned function.
embedPow', embedDec' :: forall m m' r . (Additive r, Storable r, m `Divides` m')
                     => Tagged '(m, m') (Vector r -> Vector r)
{-# INLINABLE embedPow' #-}
{-# INLINABLE embedDec' #-}
embedPow' =
  let idxs = baseIndicesPow @m @m'
      -- value of output slot k, read from the input vector src
      slot src k =
        let (j0, j1) = idxs U.! k
        in if j0 == 0 then src ! j1 else zero
  in tag $ \src -> generate (U.length idxs) (slot src)
embedDec' =
  let idxs = baseIndicesDec @m @m'
      -- value of output slot k, read from the input vector src
      slot src k = case idxs U.! k of
        Nothing -> LP.zero
        Just (sh, flipSign) ->
          let x = src ! sh
          in if flipSign then negate x else x
  in tag $ \src -> generate (U.length idxs) (slot src)
-- | CRT-basis embedding from index @m@ to index @m'@, realised as a gather
-- through the index table 'baseIndicesCRT'.
--
-- Before yielding the gather, this runs 'crtInfo' for @m'@ in @mon@ and
-- discards its 'CRTInfo' result. The action is there only for its effect:
-- any failure of 'crtInfo' in @mon@ propagates to the whole computation.
embedCRT' :: forall m m' mon r . (CRTrans mon r, Storable r, m `Divides` m')
          => TaggedT '(m, m') mon (Vector r -> Vector r)
-- INLINABLE, matching the other exported transforms in this module, so the
-- unfolding is available for specialisation in other modules (it is also
-- used by the INLINABLE 'twaceCRT'').
{-# INLINABLE embedCRT' #-}
embedCRT' =
  lift (proxyT crtInfo (Proxy::Proxy m') :: mon (CRTInfo r)) >>
  tagT (pure $ backpermute' $ baseIndicesCRT @m @m')
-- | Splits a vector into a list of vectors, one per index vector in
-- 'extIndicesCoeffs'. The @i@-th result is the gather ('backpermute'') of
-- the input through the @i@-th index vector.
coeffs' :: forall m m' r . (Storable r, m `Divides` m')
        => Tagged '(m, m') (Vector r -> [Vector r])
{-# INLINABLE coeffs' #-}
coeffs' =
  -- Bind the index table outside the lambda so it is shared by every
  -- application of the returned function (as with the original 'flip' form).
  let idxs = extIndicesCoeffs @m @m'
  in tag $ \v -> V.toList $ V.map (\is -> backpermute' is v) idxs
-- | Twace for the powerful/decoding representations (per the name): a
-- pure gather of the input through the index table 'extIndicesPowDec'.
twacePowDec' :: forall m m' r . (Storable r, m `Divides` m')
             => Tagged '(m, m') (Vector r -> Vector r)
{-# INLINABLE twacePowDec' #-}
twacePowDec' = tag gather
  where
    gather = backpermute' (extIndicesPowDec @m @m')
-- | Reads a 'Kron' into a vector whose length is 'totientFact' of @m@.
-- Entry @i@ is @indexK v i 0@, i.e. the second index is fixed at 0.
kronToVec :: forall m r . (Fact m, Ring r, Storable r) => Kron r -> Vector r
kronToVec v = generate (totientFact @m) firstCol
  where
    firstCol i = indexK v i 0
-- | Twace in the CRT representation, as a 'TaggedT' computation in @mon@.
--
-- The following actions run in @mon@, in this order:
--
-- * 'gCRTK' for @m'@
-- * 'gInvCRTK' for @m@
-- * 'embedCRT''
-- * 'crtInfo' for @m'@
--
-- If they all succeed, the result is a function that does three things:
--
-- 1. It multiplies its input pointwise by a precomputed @tweak@ vector:
--    the embedded @gInvCRTK@ vector times the @gCRTK@ vector, scaled by
--    the second component of 'crtInfo' times 'valueHatFact' of @m@.
-- 2. It gathers the product through 'extIndicesCRT'.
-- 3. It sums each of the consecutive blocks of length
--    @totientFact m' `div` totientFact m@ into one output entry.
twaceCRT' :: forall mon m m' r .
             (Storable r, CRTrans mon r, m `Divides` m')
             => TaggedT '(m, m') mon (Vector r -> Vector r)
{-# INLINABLE twaceCRT' #-}
twaceCRT' = tagT $ do
  gBig    <- kronToVec @m' <$> gCRTK @m'
  gInvSm  <- kronToVec @m <$> gInvCRTK @m
  embedFn <- untagT $ embedCRT' @m @m'
  (_, invHatBig) <- proxyT crtInfo (Proxy::Proxy m')
  let outLen   = totientFact @m
      blockLen = totientFact @m' `div` outLen
      scale    = invHatBig * fromIntegral (valueHatFact @m)
      -- built once, then captured and shared by every call of 'apply'
      tweak    = SV.zipWith (\x y -> x * y * scale) (embedFn gInvSm) gBig
      gatherIx = extIndicesCRT @m @m'
      -- sum of the i-th length-blockLen block of w
      sumBlock w i = foldl1' (+) (SV.unsafeSlice (i * blockLen) blockLen w)
      apply arr =
        let w = backpermute' gatherIx (SV.zipWith (*) tweak arr)
        in generate outLen (sumBlock w)
  return apply
-- | One vector per @k@ in @[0 .. phi' `div` phi - 1]@, where @phi@ and
-- @phi'@ are the second and third components of 'indexInfo'.
--
-- Each vector has length @phi'@. It holds 'one' at every position whose
-- 'baseIndicesPow' entry is @(k, 0)@ and 'zero' everywhere else.
powBasisPow' :: forall m m' r . (m `Divides` m', Ring r, SV.Storable r)
             => Tagged '(m, m') [SV.Vector r]
powBasisPow' = tag [ unitAt k | k <- [0 .. phi' `div` phi - 1] ]
  where
    (_, phi, phi', _) = indexInfo @m @m'
    idxs = baseIndicesPow @m @m'
    unitAt k = generate phi' $ \j ->
      let (j0, j1) = idxs U.! j
      in if j0 == k && j1 == 0 then one else zero
-- | Builds one vector for each coset @is@ returned by 'partitionCosets'
-- (instantiated at @m@, @m'@ and the prime @p@).
--
-- Each vector has length 'totientFact' of @m'@. Entry @j@ is @hinv@ times
-- the 'trace' of a sum over @i@ in @is@, where each term is the
-- @(j, zmsToIndexFact i)@ entry of the 'Kron' returned by 'twCRTs'.
-- @hinv@ is the reciprocal of 'valueHatFact' of @m'@.
--
-- The summed terms have type @GF (ZqBasic p Int64) d@, where @d@ is
-- 'order' of @p@ for @m'@. 'trace' maps them to @ZqBasic p Int64@.
--
-- Calls 'error' if 'twCRTs' for @m'@ returns 'Nothing'.
crtSetDec' :: forall m m' p .
              (m `Divides` m', Prime p, Coprime (PToF p) m',
               Reflects p Int64, IrreduciblePoly (ZqBasic p Int64))
              => Tagged '(m, m') [SV.Vector (ZqBasic p Int64)]
{-# INLINABLE crtSetDec' #-}
crtSetDec' =
  let p = valuePrime @p
      phi = totientFact @m'
      d = order @m' p
      h :: Int = valueHatFact @m'
      hinv = recip $ fromIntegral h
  -- d is only known at runtime. 'reify' lifts it to the type variable d
  -- bound by the Proxy pattern, so it can be the type argument of GF below.
  in reify d $ \(_::Proxy d) -> do
      -- A Nothing from twCRTs is treated as an internal invariant violation.
      let twCRTs' :: Kron (GF (ZqBasic p Int64) d)
            = fromMaybe (error "internal error: crtSetDec': twCRTs") $ twCRTs @m'
          zmsToIdx = T.zmsToIndexFact @m'
          -- the (j, zmsToIdx i) entry of the Kron twCRTs'
          elt j i = indexK twCRTs' j (zmsToIdx i)
          cosets = partitionCosets @m @m' p
      return $ LP.map (\is -> generate phi
                        (\j -> hinv * trace
                               (LP.sum $ LP.map (elt j) is))) cosets