{-# LANGUAGE ConstraintKinds, DataKinds, FlexibleContexts,
             FlexibleInstances, GADTs, GeneralizedNewtypeDeriving,
             InstanceSigs, MultiParamTypeClasses, NoImplicitPrelude,
             PolyKinds, RankNTypes, RebindableSyntax, RoleAnnotations,
             ScopedTypeVariables, StandaloneDeriving, TupleSections,
             TypeFamilies, TypeOperators, TypeSynonymInstances,
             UndecidableInstances #-}

-- | Wrapper for a C++ implementation of the 'Tensor' interface.

module Crypto.Lol.Cyclotomic.Tensor.CTensor
( CT ) where

import Algebra.Additive     as Additive (C)
import Algebra.Module       as Module (C)
import Algebra.ZeroTestable as ZeroTestable (C)

import Control.Applicative    hiding ((*>))
import Control.Arrow          ((***))
import Control.DeepSeq
import Control.Monad.Except
import Control.Monad.Identity (Identity (..), runIdentity)
import Control.Monad.Random
import Control.Monad.Trans    as T (lift)

import Data.Coerce
import Data.Constraint              hiding ((***))
import Data.Int
import Data.Maybe
import Data.Traversable             as T
import Data.Vector.Generic          as V (fromList, toList, unzip)
import Data.Vector.Storable         as SV (Vector, convert, foldl',
                                           foldl1', fromList, generate,
                                           length, map, mapM, replicate,
                                           replicateM, thaw, thaw, toList,
                                           unsafeFreeze, unsafeSlice,
                                           unsafeWith, zipWith, (!))
import Data.Vector.Storable.Mutable as SM hiding (replicate)

import Foreign.Marshal.Utils (with)
import Foreign.Ptr
import Test.QuickCheck       hiding (generate)

import Crypto.Lol.CRTrans
import Crypto.Lol.Cyclotomic.Tensor
import Crypto.Lol.Cyclotomic.Tensor.CTensor.Backend
import Crypto.Lol.Cyclotomic.Tensor.CTensor.Extension
import Crypto.Lol.GaussRandom
import Crypto.Lol.Prelude                             as LP hiding
                                                             (replicate,
                                                             unzip, zip)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.FiniteField
import Crypto.Lol.Types.IZipVector
import Crypto.Lol.Types.Proto
import Crypto.Lol.Types.RRq
import Crypto.Lol.Types.ZqBasic

import Crypto.Proto.RLWE.Kq
import Crypto.Proto.RLWE.Rq

import Data.Foldable as F
import Data.Sequence as S (fromList)

import System.IO.Unsafe (unsafePerformIO)

-- | Newtype wrapper around a Vector. The index @m@ does not occur in the
-- wrapped value; 'unCT' extracts the underlying vector. The 'Show', 'Eq'
-- and 'NFData' instances are derived.
newtype CT' (m :: Factored) r = CT' { unCT :: Vector r }
                              deriving (Show, Eq, NFData)

-- the first argument, though phantom, affects representation
-- (so its role is declared explicitly instead of being left to inference)
type role CT' representational nominal

-- GADT wrapper that distinguishes between Storable and unrestricted
-- element types

-- | An implementation of 'Tensor' backed by C++ code.
--
-- The 'CT' constructor carries a 'Storable' dictionary for the element
-- type together with a 'CT''; the 'ZV' constructor holds an 'IZipVector'
-- and places no constraint on the element type.
data CT (m :: Factored) r where
  CT :: Storable r => CT' m r -> CT m r
  ZV :: IZipVector m r -> CT m r

-- Standalone deriving is used because 'CT' is a GADT.
deriving instance Show r => Show (CT m r)

-- | Equality on 'CT'. Two values built with the same constructor are
-- compared directly; in a mixed comparison the 'ZV' side is first
-- converted with 'toCT' (the 'Storable' dictionary comes from matching
-- on the 'CT' side).
instance Eq r => Eq (CT m r) where
  ZV a == ZV b = a == b
  CT a == CT b = a == b
  a@(CT _) == b = a == toCT b
  a == b@(CT _) = b == toCT a

-- | Conversion between a 'CT' over @ZqBasic q Int64@ and the 'Rq' message.
--
-- 'fromProto' checks the index, the number of entries and the modulus
-- found in the message against the values expected from the types, and
-- fails with 'throwError' if any of them disagrees.
instance (Fact m, Reflects q Int64) => Protoable (CT m (ZqBasic q Int64)) where
  type ProtoType (CT m (ZqBasic q Int64)) = Rq

  -- a 'ZV' value is converted to the 'CT' representation first
  toProto x@(ZV _) = toProto $ toCT x
  toProto (CT (CT' coeffVec)) = Rq idx (fromIntegral modulus) entries
    where
      idx     = fromIntegral $ proxy valueFact (Proxy::Proxy m)
      modulus = proxy value (Proxy::Proxy q) :: Int64
      entries = S.fromList $ SV.toList $ SV.map LP.lift coeffVec

  fromProto (Rq gotM gotQ entries) =
    if sameIndex && sameLength && sameModulus
    then return $ CT $ CT' $ SV.map reduce $ SV.fromList $ F.toList entries
    else throwError $
         "An error occurred while reading the proto type for CT.\n\
         \Expected m=" ++ show wantM ++ ", got " ++ show gotM ++ "\n\
         \Expected n=" ++ show wantN ++ ", got " ++ show gotN ++ "\n\
         \Expected q=" ++ show wantQ ++ ", got " ++ show gotQ ++ "."
    where
      -- values dictated by the type-level index and modulus
      wantM = proxy valueFact (Proxy::Proxy m) :: Int
      wantQ = proxy value (Proxy::Proxy q) :: Int64
      wantN = proxy totientFact (Proxy::Proxy m)
      -- number of entries actually present in the message
      gotN  = F.length entries
      sameIndex   = wantM == fromIntegral gotM
      sameLength  = gotN == wantN
      sameModulus = fromIntegral wantQ == gotQ

-- | Conversion between a 'CT' over @RRq q Double@ and the 'Kq' message.
--
-- 'fromProto' checks the index, the number of entries and the modulus
-- found in the message against the values expected from the types, and
-- fails with 'throwError' if any of them disagrees.
instance (Fact m, Reflects q Double) => Protoable (CT m (RRq q Double)) where
  type ProtoType (CT m (RRq q Double)) = Kq

  -- a 'ZV' value is converted to the 'CT' representation first
  toProto x@(ZV _) = toProto $ toCT x
  toProto (CT (CT' coeffVec)) = Kq idx modulus entries
    where
      idx     = fromIntegral $ proxy valueFact (Proxy::Proxy m)
      modulus = proxy value (Proxy::Proxy q) :: Double
      entries = S.fromList $ SV.toList $ SV.map LP.lift coeffVec

  fromProto (Kq gotM gotQ entries) =
    if sameIndex && sameLength && sameModulus
    then return $ CT $ CT' $ SV.map reduce $ SV.fromList $ F.toList entries
    else throwError $
         "An error occurred while reading the proto type for CT.\n\
         \Expected m=" ++ show wantM ++ ", got " ++ show gotM ++ "\n\
         \Expected n=" ++ show wantN ++ ", got " ++ show gotN ++ "\n\
         \Expected q=" ++ show (round wantQ :: Int64) ++ ", got " ++ show gotQ ++ "."
    where
      -- values dictated by the type-level index and modulus
      wantM = proxy valueFact (Proxy::Proxy m) :: Int
      wantQ = proxy value (Proxy::Proxy q) :: Double
      wantN = proxy totientFact (Proxy::Proxy m)
      -- number of entries actually present in the message
      gotN  = F.length entries
      sameIndex   = wantM == fromIntegral gotM
      sameLength  = gotN == wantN
      sameModulus = wantQ == gotQ

-- | Normalize to the 'CT' constructor: a 'ZV' value is converted with
-- 'zvToCT'', a 'CT' value is returned unchanged.
toCT :: (Storable r) => CT m r -> CT m r
toCT t = case t of
  ZV zv -> CT (zvToCT' zv)
  CT _  -> t

-- | Normalize to the 'ZV' constructor: a 'CT' value is converted to an
-- 'IZipVector' (calling 'error' if 'iZipVector' yields 'Nothing'), a 'ZV'
-- value is returned unchanged.
toZV :: (Fact m) => CT m r -> CT m r
toZV t = case t of
  ZV _ -> t
  CT (CT' vec) ->
    ZV . fromMaybe (error "toZV: internal error") . iZipVector $ convert vec

-- | Convert the vector underlying an 'IZipVector' to a storable one and
-- wrap it as a 'CT''.
zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r
zvToCT' = CT' . convert . unIZipVector

-- | Lift a function on 'CT'' to one on 'CT'. A 'ZV' argument is converted
-- with 'zvToCT'' first; the result always uses the 'CT' constructor.
wrap :: (Storable r) => (CT' l r -> CT' m r) -> (CT l r -> CT m r)
wrap f t = case t of
  CT v  -> CT (f v)
  ZV zv -> CT (f (zvToCT' zv))

-- | Monadic variant of 'wrap': lift an effectful function on 'CT'' to one
-- on 'CT', converting a 'ZV' argument with 'zvToCT'' first.
wrapM :: (Storable r, Monad mon) => (CT' l r -> mon (CT' m r))
         -> (CT l r -> mon (CT m r))
wrapM f t = case t of
  CT v  -> fmap CT (f v)
  ZV zv -> fmap CT (f (zvToCT' zv))

-- convert a CT' *twace* signature to a Tagged one: a (possibly
-- Maybe-wrapped) function from index m' to index m becomes a function on
-- plain Vectors tagged by the pair '(m,m')
type family Tw (r :: *) :: * where
  Tw (CT' m' r -> CT' m r) = Tagged '(m,m') (Vector r -> Vector r)
  Tw (Maybe (CT' m' r -> CT' m r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)

-- convert a CT' *embed* signature to a Tagged one: a (possibly
-- Maybe-wrapped) function from index m to index m' becomes a function on
-- plain Vectors tagged by the pair '(m,m')
type family Em r where
  Em (CT' m r -> CT' m' r) = Tagged '(m,m') (Vector r -> Vector r)
  Em (Maybe (CT' m r -> CT' m' r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)


---------- NUMERIC PRELUDE INSTANCES ----------

-- CJP: Additive, Ring are not necessary when we use zipWithT
-- EAC: This has performance implications for the CT backend,
--      which used a (very fast) C function for (*) and (+)
-- | Entry-wise additive structure. Operands that are not already in the
-- 'CT' representation are converted with 'toCT' before the operation.
instance (Additive r, Storable r, Fact m, Dispatch r)
  => Additive.C (CT m r) where
  zero = CT $ repl zero

  x + y = case (x, y) of
    (CT (CT' a), CT (CT' b)) -> CT . CT' $ SV.zipWith (+) a b
    _                        -> toCT x + toCT y

  negate x = case x of
    -- EAC: This probably should be converted to C code
    CT (CT' a) -> CT . CT' $ SV.map negate a
    ZV _       -> negate (toCT x)

{-
instance (Fact m, Ring r, Storable r, Dispatch r)
  => Ring.C (CT m r) where
  (CT a@(CT' _)) * (CT b@(CT' _)) = CT $ (untag $ cZipDispatch dmul) a b

  fromInteger = CT . repl . fromInteger
-}

-- | A tensor is zero exactly when every entry is zero. The 'ZV' case
-- defers to the instance for the wrapped 'IZipVector'.
instance (ZeroTestable r, Storable r, Fact m)
         => ZeroTestable.C (CT m r) where
  --{-# INLINABLE isZero #-}
  isZero t = case t of
    ZV zv        -> isZero zv
    CT (CT' vec) -> SV.foldl' stillZero True vec
      where stillZero acc x = acc && isZero x

-- | Scalar multiplication of a tensor over @fp@ by an element of @GF fp d@.
--
-- The entries are passed as a list through the 'Coeffs' module instance
-- and the result is repacked in the same representation ('CT' or 'ZV') as
-- the argument.
instance (GFCtx fp d, Fact m, Additive (CT m fp))
    => Module.C (GF fp d) (CT m fp) where

  r *> v = case v of
    CT (CT' arr) -> CT $ CT' $ SV.fromList $ unCoeffs $ r *> Coeffs $ SV.toList arr
    -- Report a descriptive internal error (in the style of 'toZV') rather
    -- than the anonymous failure of 'fromJust' if repacking fails.
    ZV zv -> ZV $ fromMaybe (error "CT (*>): internal error: iZipVector rejected the scaled coefficient vector") $
             iZipVector $ V.fromList $ unCoeffs $ r *> Coeffs $ V.toList $ unIZipVector zv

---------- Category-theoretic instances ----------

-- | 'fmap' is defined through the 'Applicative' instance.
instance Fact m => Functor (CT m) where
  -- Functor instance is implied by Applicative laws
  fmap f = (pure f <*>)

-- | Entry-wise application, delegated to the 'IZipVector' instance.
-- 'pure' always produces a 'ZV'; operands in the 'CT' representation are
-- converted with 'toZV' before being combined.
instance Fact m => Applicative (CT m) where
  pure = ZV . pure

  (ZV f) <*> (ZV a) = ZV (f <*> a)
  f@(ZV _) <*> v@(CT _) = f <*> toZV v
  -- Previously there was no clause for a 'CT'-constructed left operand,
  -- leaving the match non-exhaustive. Such a value needs a 'Storable'
  -- dictionary for a function type, so it is not expected to occur, but
  -- handling it keeps '<*>' total: convert and retry with the clauses above.
  f@(CT _) <*> v = toZV f <*> v

-- | Folding is obtained from the 'Traversable' instance.
instance Fact m => Foldable (CT m) where
  -- Foldable instance is implied by Traversable
  foldMap f = foldMapDefault f

-- | Traversal is delegated to the 'IZipVector' instance; a value in the
-- 'CT' representation is converted with 'toZV' first, so the result is
-- always a 'ZV'.
instance Fact m => Traversable (CT m) where
  traverse f t = case t of
    ZV zv -> ZV <$> T.traverse f zv
    CT _  -> T.traverse f (toZV t)

-- | The 'Tensor' instance for 'CT'. Most methods are defined on the
-- 'CT'' representation by helpers in this module (or imported ones) and
-- lifted to 'CT' with 'wrap' / 'wrapM', which convert a 'ZV' argument
-- first.
instance Tensor CT where

  -- element types must be 'Storable' and have a 'Dispatch' instance
  type TElt CT r = (Storable r, Dispatch r)

  -- entailment witnesses
  entailIndexT = tag $ Sub Dict
  entailEqT = tag $ Sub Dict
  entailZTT = tag $ Sub Dict
  -- entailRingT = tag $ Sub Dict
  entailNFDataT = tag $ Sub Dict
  entailRandomT = tag $ Sub Dict
  entailShowT = tag $ Sub Dict
  entailModuleT = tag $ Sub Dict

  scalarPow = CT . scalarPow' -- Vector code

  -- the following go through 'basicDispatch' with the backend functions
  -- 'dl' and 'dlinv'
  l = wrap $ untag $ basicDispatch dl
  lInv = wrap $ untag $ basicDispatch dlinv

  mulGPow = wrap mulGPow'
  mulGDec = wrap $ untag $ basicDispatch dmulgdec

  -- 'divGPow'' post-processes the backend result with 'checkDiv'
  divGPow = wrapM divGPow'
  -- we divide by p in the C code (for divGDec only(?)), do NOT call checkDiv!
  divGDec = wrapM $ Just . untag (basicDispatch dginvdec)

  -- the five components are built, in order, from: 'repl', entry-wise
  -- 'dmul' by 'gCRT', entry-wise 'dmul' by 'gInvCRT', 'ctCRT', 'ctCRTInv'
  crtFuncs = (,,,,) <$>
    return (CT . repl) <*>
    (wrap . untag (cZipDispatch dmul) <$> gCRT) <*>
    (wrap . untag (cZipDispatch dmul) <$> gInvCRT) <*>
    (wrap <$> untagT ctCRT) <*>
    (wrap <$> untagT ctCRTInv)

  -- Vector-level functions coerced to 'CT'' with 'coerceTw' / 'coerceEm'
  twacePowDec = wrap $ runIdentity $ coerceTw twacePowDec'
  embedPow = wrap $ runIdentity $ coerceEm embedPow'
  embedDec = wrap $ runIdentity $ coerceEm embedDec'

  tGaussianDec v = CT <$> cDispatchGaussian v
  --tGaussianDec v = CT <$> coerceT' (gaussianDec v)

  -- we do not wrap this function because (currently) it can only be called on lifted types
  gSqNormDec (CT v) = untag gSqNormDec' v
  gSqNormDec (ZV v) = gSqNormDec (CT $ zvToCT' v)

  crtExtFuncs = (,) <$> (wrap <$> coerceTw twaceCRT')
                    <*> (wrap <$> coerceEm embedCRT')

  coeffs = wrapM $ coerceCoeffs coeffs'

  powBasisPow = (CT <$>) <$> coerceBasis powBasisPow'

  crtSetDec = (CT <$>) <$> coerceBasis crtSetDec'

  -- the entry-wise combinators below operate on the 'CT' representation;
  -- 'ZV' arguments are converted with 'toCT' and the method is retried
  fmapT f (CT v) = CT $ coerce (SV.map f) v
  fmapT f v@(ZV _) = fmapT f $ toCT v

  fmapTM f (CT (CT' v)) = (CT . CT') <$> SV.mapM f v
  fmapTM f v@(ZV _) = fmapTM f $ toCT v

  zipWithT f (CT (CT' v1)) (CT (CT' v2)) = CT $ CT' $ SV.zipWith f v1 v2
  zipWithT f v1 v2 = zipWithT f (toCT v1) (toCT v2)

  unzipT (CT (CT' v)) = (CT . CT') *** (CT . CT') $ unzip v
  unzipT v = unzipT $ toCT v

  -- NOTE(review): 'entailModuleT' is the only defined method without an
  -- INLINABLE pragma below -- confirm whether that is intentional.
  {-# INLINABLE entailIndexT #-}
  {-# INLINABLE entailEqT #-}
  {-# INLINABLE entailZTT #-}
  {-# INLINABLE entailNFDataT #-}
  {-# INLINABLE entailRandomT #-}
  {-# INLINABLE entailShowT #-}
  {-# INLINABLE scalarPow #-}
  {-# INLINABLE l #-}
  {-# INLINABLE lInv #-}
  {-# INLINABLE mulGPow #-}
  {-# INLINABLE mulGDec #-}
  {-# INLINABLE divGPow #-}
  {-# INLINABLE divGDec #-}
  {-# INLINABLE crtFuncs #-}
  {-# INLINABLE twacePowDec #-}
  {-# INLINABLE embedPow #-}
  {-# INLINABLE embedDec #-}
  {-# INLINABLE tGaussianDec #-}
  {-# INLINABLE gSqNormDec #-}
  {-# INLINABLE crtExtFuncs #-}
  {-# INLINABLE coeffs #-}
  {-# INLINABLE powBasisPow #-}
  {-# INLINABLE crtSetDec #-}
  {-# INLINABLE fmapT #-}
  {-# INLINABLE fmapTM #-}
  {-# INLINABLE zipWithT #-}
  {-# INLINABLE unzipT #-}


-- | Untag a Vector-level twace-style function and coerce it to operate on
-- 'CT'' (from index @m'@ to index @m@).
coerceTw :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m' r -> CT' m r)
coerceTw tagged = fmap coerce (untagT tagged)

-- | Untag a Vector-level embed-style function and coerce it to operate on
-- 'CT'' (from index @m@ to index @m'@).
coerceEm :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m r -> CT' m' r)
coerceEm tagged = fmap coerce (untagT tagged)

-- | Useful coercion for defining @coeffs@ in the @Tensor@
-- interface. Using 'coerce' alone is insufficient for type inference.
-- The tagged Vector-level function is turned into one from a 'CT'' of
-- index @m'@ to a list of 'CT''s of index @m@.
coerceCoeffs :: (Fact m, Fact m')
  => Tagged '(m,m') (Vector r -> [Vector r]) -> CT' m' r -> [CT' m r]
coerceCoeffs = coerce

-- | Useful coercion for defining @powBasisPow@ and @crtSetDec@ in the @Tensor@
-- interface. Using 'coerce' alone is insufficient for type inference.
-- A list of Vectors tagged by @'(m,m')@ becomes a list of 'CT''s of
-- index @m'@ tagged by @m@.
coerceBasis ::
  (Fact m, Fact m')
  => Tagged '(m,m') [Vector r] -> Tagged m [CT' m' r]
coerceBasis = coerce

-- | 'CT''-level implementation behind @mulGPow@: the backend function
-- 'dmulgpow' run through 'basicDispatch'.
mulGPow' :: (TElt CT r, Fact m, Additive r) => CT' m r -> CT' m r
mulGPow' = untag (basicDispatch dmulgpow)

-- | 'CT''-level implementation behind @divGPow@: the backend function
-- 'dginvpow' run through 'basicDispatch', with the result post-processed
-- by 'checkDiv' (which yields 'Nothing' if an exact division fails).
divGPow' :: (TElt CT r, Fact m, IntegralDomain r, ZeroTestable r)
            => CT' m r -> Maybe (CT' m r)
divGPow' = untag (checkDiv (basicDispatch dginvpow))

-- | Run a backend function on a mutable copy of the input and return the
-- frozen copy. The callback receives, in order: a pointer to the copy, the
-- totient of @m@, a pointer to the marshalled prime-power factors of @m@,
-- and the number of those factors.
withBasicArgs :: forall m r . (Fact m, Storable r)
  => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
     -> CT' m r -> IO (CT' m r)
withBasicArgs f = \(CT' input) -> do
    -- 'thaw' copies, so the input vector itself is not modified
    buf <- SV.thaw input
    SM.unsafeWith buf $ \bufPtr ->
      SV.unsafeWith ppFactors $ \facPtr ->
        f bufPtr phi facPtr nFactors
    CT' <$> unsafeFreeze buf
  where
    -- these depend only on @m@ and are bound outside the lambda
    ppFactors = proxy (marshalFactors <$> ppsFact) (Proxy::Proxy m)
    phi       = proxy (fromIntegral <$> totientFact) (Proxy::Proxy m)
    nFactors  = fromIntegral $ SV.length ppFactors

-- | Turn a backend function into a pure, tagged function on 'CT'' by
-- running it through 'withBasicArgs' under 'unsafePerformIO'.
-- NOTE(review): this presumes the backend call has no observable effect
-- beyond writing to the buffer it is handed -- confirm against the backend.
basicDispatch :: (Storable r, Fact m, Additive r)
                 => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
                     -> Tagged m (CT' m r -> CT' m r)
basicDispatch f = return runBackend
  where runBackend = unsafePerformIO . withBasicArgs f

-- | Run the backend function 'dnorm' through 'withBasicArgs' and return
-- the entry at position 0 of the resulting vector.
gSqNormDec' :: (Storable r, Fact m, Additive r, Dispatch r)
               => Tagged m (CT' m r -> r)
gSqNormDec' = return (firstEntry . unsafePerformIO . runNorm)
  where
    runNorm = withBasicArgs dnorm
    firstEntry (CT' v) = v ! 0

-- | Build the function that runs the backend 'dcrt' through
-- 'withBasicArgs', handing it the vectors produced by 'ru' as a pointer
-- array.
ctCRT :: (Storable r, CRTrans mon r, Dispatch r, Fact m)
         => TaggedT m mon (CT' m r -> CT' m r)
ctCRT = do
  roots <- ru
  return $ \x ->
    unsafePerformIO $ withPtrArray roots $ \rootsPtr ->
      withBasicArgs (dcrt rootsPtr) x

-- CTensor CRT^(-1) functions take inverse rus
-- | Build the function that runs the backend 'dcrtinv' through
-- 'withBasicArgs', handing it the vectors produced by 'ruInv' as a pointer
-- array and a pointer to the second component of 'crtInfo'.
ctCRTInv :: (Storable r, CRTrans mon r, Dispatch r, Fact m)
         => TaggedT m mon (CT' m r -> CT' m r)
ctCRTInv = do
  mhatInv <- snd <$> crtInfo
  invRoots <- ruInv
  return $ \x ->
    unsafePerformIO $ withPtrArray invRoots $ \rootsPtr ->
      with mhatInv $ \mhatInvPtr ->
        withBasicArgs (dcrtinv rootsPtr mhatInvPtr) x

-- | Post-process a tagged function on 'CT'': after applying it, divide
-- every entry by the value of 'oddRadicalFact' using 'divIfDivis'. The
-- overall result is 'Nothing' as soon as one entry is not divisible.
checkDiv :: (Storable r, IntegralDomain r, ZeroTestable r, Fact m)
    => Tagged m (CT' m r -> CT' m r) -> Tagged m (CT' m r -> Maybe (CT' m r))
checkDiv f = do
  g <- f
  oddRad <- fromIntegral <$> oddRadicalFact
  return $ \x ->
    CT' <$> SV.mapM (\entry -> divIfDivis entry oddRad) (unCT (g x))

-- | @divIfDivis num den@ is 'Just' the quotient of @num@ by @den@ when the
-- remainder (per 'divMod') is zero, and 'Nothing' otherwise.
divIfDivis :: (IntegralDomain r, ZeroTestable r) => r -> r -> Maybe r
divIfDivis num den = if isZero remainder then Just quotient else Nothing
  where (quotient, remainder) = num `divMod` den

-- | Turn a binary backend function into a pure, tagged function on
-- 'CT''. The backend receives a pointer to a mutable copy of the first
-- argument, a pointer to the second argument, and the totient of @m@; the
-- frozen copy is the result.
cZipDispatch :: (Storable r, Fact m, Additive r)
  => (Ptr r -> Ptr r -> Int64 -> IO ())
     -> Tagged m (CT' m r -> CT' m r -> CT' m r)
cZipDispatch f = do -- in Tagged m
  phi <- fromIntegral <$> totientFact
  return $ \(CT' lhs) (CT' rhs) -> CT' $ unsafePerformIO $ do
    -- 'thaw' copies, so @lhs@ itself is not modified
    buf <- SV.thaw lhs
    SM.unsafeWith buf $ \outPtr ->
      SV.unsafeWith rhs $ \inPtr ->
        f outPtr inPtr phi
    unsafeFreeze buf

-- | Sample via the backend function 'dgaussdec': draw values with
-- 'realGaussians' (as many as the totient of @m@, with the parameter @var@
-- scaled by @m / rad(m)@) and pass them to the backend together with the
-- vectors produced by 'ruInv'. Calls 'error' if 'ruInv' yields 'Nothing'.
cDispatchGaussian :: forall m r var rnd .
         (Storable r, Transcendental r, Dispatch r, Ord r,
          Fact m, ToRational var, Random r, MonadRandom rnd)
         => var -> rnd (CT' m r)
cDispatchGaussian var = flip proxyT (Proxy::Proxy m) $ do -- in TaggedT m rnd
  -- get rus for (Complex r)
  invRoots <- mapTaggedT (return . fromMaybe (error "complexGaussianRoots")) ruInv
  dim <- pureT totientFact
  idx <- pureT valueFact
  radical <- pureT radicalFact
  let scaledVar = var * fromIntegral (idx `div` radical)
  samples <- T.lift $ realGaussians scaledVar dim
  return $ unsafePerformIO $
    withPtrArray invRoots $ \rootsPtr ->
      withBasicArgs (dgaussdec rootsPtr) (CT' samples)

-- | Generates every entry independently with the element type's
-- generator (via 'replM'); shrinking is disabled.
instance (Arbitrary r, Fact m, Storable r) => Arbitrary (CT' m r) where
  shrink = shrinkNothing
  arbitrary = replM arbitrary

-- | Generates a 'CT'' and wraps it with the 'CT' constructor.
instance (Storable r, Arbitrary (CT' m r)) => Arbitrary (CT m r) where
  arbitrary = fmap CT arbitrary

-- | 'random' draws every entry with the element type's 'random' (via
-- 'replM'); 'randomR' is not supported and calls 'error'.
instance (Storable r, Random r, Fact m) => Random (CT' m r) where
  randomR = error "randomR nonsensical for CT'"

  --{-# INLINABLE random #-}
  random = runRand (replM (liftRand random))

-- | 'random' draws a 'CT'' and wraps it with the 'CT' constructor;
-- 'randomR' is not supported and calls 'error'.
instance (Storable r, Random (CT' m r)) => Random (CT m r) where
  randomR = error "randomR nonsensical for CT"

  --{-# INLINABLE random #-}
  random = runRand (fmap CT (liftRand random))

-- | Full evaluation delegates to the instance of whichever representation
-- is wrapped.
instance (NFData r) => NFData (CT m r) where
  rnf t = case t of
    CT inner -> rnf inner
    ZV zv    -> rnf zv

-- | A 'CT'' whose entries (as many as the totient of @m@) all equal the
-- given value.
repl :: forall m r . (Fact m, Storable r) => r -> CT' m r
repl = CT' . SV.replicate dim
  where dim = proxy totientFact (Proxy::Proxy m)

-- | Monadic variant of 'repl': run the given action once per entry (as
-- many times as the totient of @m@) and collect the results in a 'CT''.
replM :: forall m r mon . (Fact m, Storable r, Monad mon)
         => mon r -> mon (CT' m r)
replM = fmap CT' . SV.replicateM dim
  where dim = proxy totientFact (Proxy::Proxy m)

-- | The 'CT'' that has the given value at position 0 and 'zero' at every
-- other position.
scalarPow' :: forall m r . (Fact m, Additive r, Storable r) => r -> CT' m r
-- constant-term coefficient is first entry wrt powerful basis
scalarPow' = \r -> CT' $ generate dim (entryFor r)
  where
    dim = proxy totientFact (Proxy::Proxy m)
    entryFor r i = if i == 0 then r else zero

-- | For each prime-power factor @p^e@ of @m@ (as listed by 'ppsFact'),
-- a vector of length @p^e@ built from the first component of 'crtInfo':
-- entry @i@ is that function applied to @i * (m / p^e)@ for 'ru' and to
-- @-i * (m / p^e)@ for 'ruInv'.
ru, ruInv :: (CRTrans mon r, Fact m, Storable r)
   => TaggedT m mon [Vector r]
ru = do
  mval <- pureT valueFact
  wPow <- fst <$> crtInfo
  let rootsFor (p,e) =
        let pp  = p^e
            pow = mval `div` pp
        in generate pp (\i -> wPow (i*pow))
  LP.map rootsFor <$> pureT ppsFact

ruInv = do
  mval <- pureT valueFact
  wPow <- fst <$> crtInfo
  let invRootsFor (p,e) =
        let pp  = p^e
            pow = mval `div` pp
        in generate pp (\i -> wPow $ -i*pow)
  LP.map invRootsFor <$> pureT ppsFact

-- | Build a 'CT'' of length totient(@m@) whose entry @i@ is the element
-- at row @i@, column 0 (per 'indexM') of the given matrix.
wrapVector :: forall mon m r . (Monad mon, Fact m, Ring r, Storable r)
  => TaggedT m mon (Matrix r) -> mon (CT' m r)
wrapVector v = do
  mat <- proxyT v (Proxy::Proxy m)
  return . CT' $ generate dim (\row -> indexM mat row 0)
  where dim = proxy totientFact (Proxy::Proxy m)

-- | The matrices 'gCRTM' and 'gInvCRTM', respectively, converted to a
-- 'CT'' with 'wrapVector'.
gCRT, gInvCRT :: (Storable r, CRTrans mon r, Fact m)
                 => mon (CT' m r)
gCRT = wrapVector gCRTM
gInvCRT = wrapVector gInvCRTM

-- we can't put this in Extension with the rest of the twace/embed
-- functions because it needs access to the C backend
-- | Vector-level function behind the first component of @crtExtFuncs@.
-- The returned function multiplies its argument entry-wise by a
-- precomputed @tweak@ vector, permutes the product with 'backpermute''
-- using the indices from 'extIndicesCRT', and sums each consecutive
-- group of @reltot@ entries, yielding a vector of length totient(@m@).
twaceCRT' :: forall mon m m' r .
             (TElt CT r, CRTrans mon r, m `Divides` m')
             => TaggedT '(m, m') mon (Vector r -> Vector r)
twaceCRT' = tagT $ do
  -- 'gCRT' is taken at index m', 'gInvCRT' at index m
  (CT' g') :: CT' m' r <- gCRT
  (CT' gInv) :: CT' m r <- gInvCRT
  embed <- proxyT embedCRT' (Proxy::Proxy '(m,m'))
  indices <- pure $ proxy extIndicesCRT (Proxy::Proxy '(m,m'))
  (_, m'hatinv) <- proxyT crtInfo (Proxy::Proxy m')
  let phi = proxy totientFact (Proxy::Proxy m)
      phi' = proxy totientFact (Proxy::Proxy m')
      mhat = fromIntegral $ proxy valueHatFact (Proxy::Proxy m)
      hatRatioInv = m'hatinv * mhat
      -- number of entries summed to produce one output entry
      reltot = phi' `div` phi
      -- tweak = mhat * g' / (m'hat * g)
      tweak = SV.map (* hatRatioInv) $ SV.zipWith (*) (embed gInv) g'
  return $ \ arr -> -- take true trace after mul-by-tweak
    let v = backpermute' indices (SV.zipWith (*) tweak arr)
    in generate phi $ \i -> foldl1' (+) $ SV.unsafeSlice (i*reltot) reltot v