{-# LANGUAGE ConstraintKinds, DataKinds, DeriveDataTypeable,
             FlexibleContexts, FlexibleInstances, GADTs, InstanceSigs,
             MultiParamTypeClasses, NoImplicitPrelude, RebindableSyntax,
             RoleAnnotations, ScopedTypeVariables, StandaloneDeriving,
             TypeFamilies, TypeOperators, UndecidableInstances #-}

-- | A pure, repa-based implementation of the Tensor interface.

module Crypto.Lol.Cyclotomic.Tensor.RepaTensor
( RT ) where

import Crypto.Lol.Cyclotomic.Tensor                        as T
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.CRT
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.Extension
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.Gauss
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.GL
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.RTCommon    as RT
import Crypto.Lol.LatticePrelude                           as LP hiding ((!!))
import Crypto.Lol.Types.IZipVector

import Algebra.Additive     as Additive (C)
import Algebra.Ring         as Ring (C)
import Algebra.ZeroTestable as ZeroTestable (C)

import Control.Applicative
import Control.DeepSeq      (NFData (rnf))
import Control.Monad        (liftM)
import Control.Monad.Random
import Data.Coerce
import Data.Constraint
import Data.Foldable        as F
import Data.Maybe
import Data.Traversable     as T
import Data.Typeable
import Data.Vector.Unboxed  as U hiding (force)

import Test.QuickCheck

-- | An implementation of 'Tensor' backed by repa.
--
-- A value is held in one of two interchangeable representations; most
-- operations normalise to 'RT' before computing.
data RT (m :: Factored) r where
  -- | Strict, unboxed repa array (requires an 'Unbox' element type).
  RT :: Unbox r => !(Arr m r) -> RT m r
  -- | Boxed 'IZipVector'; places no constraint on the element type.
  ZV :: IZipVector m r -> RT m r
  deriving (Typeable)

deriving instance Show r => Show (RT m r)

-- | Equality is decided in a shared representation: two 'ZV's (or two
-- 'RT's) are compared directly, and in the mixed case the 'ZV' side is
-- first converted to 'RT'.  Matching an 'RT' constructor is what brings
-- the 'Unbox' dictionary needed by 'toRT' into scope.
instance Eq r => Eq (RT m r) where
  lhs == rhs = case (lhs, rhs) of
    (ZV x, ZV y) -> x == y
    (RT x, RT y) -> x == y
    (RT _, _)    -> lhs == toRT rhs
    (_, RT _)    -> toRT lhs == rhs

-- | Copies the boxed vector inside an 'IZipVector' into a one-dimensional,
-- unboxed repa array whose extent is the vector's length.
zvToArr :: Unbox r => IZipVector m r -> Arr m r
zvToArr zv = Arr (fromUnboxed sh uv)
  where
    uv = convert (unIZipVector zv)
    sh = Z :. U.length uv

-- | Normalises to the 'RT' constructor: an 'RT' is returned unchanged and a
-- 'ZV' is converted with 'zvToArr'.
toRT :: Unbox r => RT m r -> RT m r
toRT t = case t of
  ZV zv -> RT (zvToArr zv)
  RT _  -> t

-- EAC: this does more work than is necessary, since any vector in RT m r should have length @m@.
-- | Converts to the 'ZV' (boxed 'IZipVector') constructor; a 'ZV' is
-- returned unchanged.  The 'error' fires only if 'iZipVector' rejects the
-- array's contents, which should not happen for a well-formed 'RT'.
toZV :: Fact m => RT m r -> RT m r
toZV (RT (Arr v)) = ZV $ fromMaybe (error "toZV: internal error") $
                    iZipVector $ convert $ toUnboxed v
toZV v@(ZV _) = v

-- | Lifts a function on repa arrays to one on 'RT', converting a 'ZV'
-- argument with 'zvToArr' first.  The result always uses the 'RT'
-- constructor.
wrap :: Unbox r => (Arr l r -> Arr m r) -> RT l r -> RT m r
wrap f (RT v) = RT $ f v
wrap f (ZV v) = RT $ f $ zvToArr v

-- | Monadic analogue of 'wrap'.
wrapM :: (Unbox r, Monad mon) => (Arr l r -> mon (Arr m r))
         -> RT l r -> mon (RT m r)
wrapM f (RT v) = liftM RT $ f v
wrapM f (ZV v) = liftM RT $ f $ zvToArr v

instance Tensor RT where

  type TElt RT r = (IntegralDomain r, ZeroTestable r, Eq r, Random r,
                    NFData r, Unbox r, Elt r)

  entailIndexT = tag $ Sub Dict
  entailFullT = tag $ Sub Dict

  scalarPow = RT . scalarPow'

  l = wrap fL
  lInv = wrap fLInv

  mulGPow = wrap fGPow
  mulGDec = wrap fGDec

  divGPow = wrapM fGInvPow
  divGDec = wrapM fGInvDec

  crtFuncs = (,,,,) <$>
             (liftM (RT .) scalarCRT') <*>
             (wrap <$> mulGCRT') <*>
             (wrap <$> divGCRT') <*>
             (wrap <$> fCRT) <*>
             (wrap <$> fCRTInv)

  -- instance sigs are the cleanest way to handle many weird types
  -- coming up
  tGaussianDec :: forall v rnd m q .
                  (Fact m, OrdFloat q, Random q, TElt RT q,
                   ToRational v, MonadRandom rnd)
                  => v -> rnd (RT m q)
  tGaussianDec = liftM RT . tGaussianDec'

  twacePowDec = wrap twacePowDec'
  embedPow = wrap embedPow'
  embedDec = wrap embedDec'

  crtExtFuncs = (,) <$> (liftM wrap twaceCRT')
                    <*> (liftM wrap embedCRT')

  coeffs = wrapM coeffs'

  powBasisPow = (RT <$>) <$> powBasisPow'

  crtSetDec = (RT <$>) <$> crtSetDec'

  fmapT f (RT v) = RT $ (coerce $ force . RT.map f) v
  fmapT f v@(ZV _) = fmapT f $ toRT v

  -- Repa arrays don't have mapM, so apply to underlying Unboxed
  -- vector instead
  fmapTM f (RT (Arr arr)) = liftM (RT . Arr . fromUnboxed (extent arr)) $
                              U.mapM f $ toUnboxed arr
  fmapTM f v@(ZV _) = fmapTM f $ toRT v

---------- "Container" instances ----------

instance Fact m => Functor (RT m) where
  -- Functor instance is implied by Applicative
  fmap f x = pure f <*> x

instance Fact m => Applicative (RT m) where
  pure = ZV . pure

  (ZV f) <*> (ZV a) = ZV (f <*> a)
  f@(ZV _) <*> v@(RT _) = f <*> toZV v
  -- RT can never hold an a -> b: its constructor requires Unbox, which
  -- functions lack.  This clause is therefore unreachable; it makes the
  -- match exhaustive and gives a descriptive message instead of a
  -- generic pattern-match failure.
  (RT _) <*> _ = error "RT.<*>: internal error: RT constructor cannot hold functions"

instance Fact m => Foldable (RT m) where
  -- Foldable instance is implied by Traversable
  foldMap = foldMapDefault

instance Fact m => Traversable (RT m) where
  traverse f r@(RT _) = T.traverse f $ toZV r
  traverse f (ZV v) = ZV <$> T.traverse f v

---------- Numeric Prelude instances ----------

-- CJP: should Elt, Unbox be constraints on these instances?  It's
-- possible to zipWith on IZipVector, so it's not *necessary* to
-- convert toRT.

instance (Fact m, Additive r, Unbox r, Elt r) => Additive.C (RT m r) where
  (RT a) + (RT b) = RT $ coerce (\x -> force . RT.zipWith (+) x) a b
  a + b = toRT a + toRT b

  negate (RT a) = RT $ (coerce $ force . RT.map negate) a
  negate a = negate $ toRT a

  zero = RT $ repl zero

instance (Fact m, Ring r, Unbox r, Elt r) => Ring.C (RT m r) where
  (RT a) * (RT b) = RT $ coerce (\x -> force . RT.zipWith (*) x) a b
  a * b = (toRT a) * (toRT b)

  fromInteger = RT . repl . fromInteger

instance (Fact m, ZeroTestable r, Unbox r, Elt r) => ZeroTestable.C (RT m r) where
  -- not using 'zero' to avoid Additive r constraint: the fold keeps the
  -- first nonzero entry it meets (seeded with entry 0), so the result is
  -- zero iff every entry is zero.
  isZero (RT (Arr a)) = isZero $ foldAllS
                        (\ x y -> if isZero x then y else x)
                        (a RT.! (Z:.0)) a
  isZero (ZV v) = isZero v

---------- Miscellaneous instances ----------

-- CJP: shouldn't these instances be defined in RTCommon, where the
-- Arr data type is defined?
-- NOTE(review): the instance heads below mention RT, which is defined in
-- this module, so they are not orphans.  They do depend on Random /
-- Arbitrary instances for 'Arr', which are not defined in this file
-- (presumably in RTCommon -- confirm).

instance (Unbox r, Random (Arr m r)) => Random (RT m r) where
  random = runRand $ liftM RT (liftRand random)

  -- There is no meaningful range between two tensors; always fails.
  randomR = error "randomR nonsensical for RT"

instance (Unbox r, Arbitrary (Arr m r)) => Arbitrary (RT m r) where
  arbitrary = RT <$> arbitrary

instance (NFData r) => NFData (RT m r) where
  rnf (RT v) = rnf v
  rnf (ZV v) = rnf v