{-|
Module      : Crypto.Lol.Cyclotomic.Tensor.CPP.Backend
Description : Transforms Haskell types into C counterparts.
Copyright   : (c) Eric Crockett, 2011-2017
                  Chris Peikert, 2011-2018
License     : GPL-3
Maintainer  : ecrockett0@gmail.com
Stability   : experimental
Portability : POSIX

This module contains the functions to transform Haskell types into their
C counterparts, and to transform polymorphic Haskell functions into C function
calls in a type-safe way.
-}

{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}

{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}

module Crypto.Lol.Cyclotomic.Tensor.CPP.Backend
( dcrtZq, dcrtinvZq
, dlZq, dlinvZq
, dmulgpowZq, dmulgdecZq
, dginvpowZq, dginvdecZq
, dmulZq
, dcrtC, dcrtinvC
, dlC, dlinvC
, dmulgpowC, dmulgdecC
, dginvpowC, dginvdecC
, dmulC
, dlDouble, dlinvDouble
, dmulgpowDouble, dmulgdecDouble
, dginvpowDouble, dginvdecDouble
, dgaussdecDouble
, dnormDouble
, dlRRq,dlinvRRq
, dlInt64, dlinvInt64
, dmulgpowInt64, dmulgdecInt64
, dginvpowInt64, dginvdecInt64
, dnormInt64
, marshalFactors
, CPP
, withArray, withPtrArray
) where

import Crypto.Lol.Prelude              as LP (Complex, PP, map, mapM_)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.Unsafe.RRq
import Crypto.Lol.Types.Unsafe.ZqBasic

import Control.Arrow                 ((***))
import Data.Int
import Data.Vector.Storable          as SV (Vector, fromList,
                                            unsafeToForeignPtr0)
import Data.Vector.Storable.Internal (getPtr)

import Foreign.ForeignPtr    (touchForeignPtr, withForeignPtr)
import Foreign.Marshal.Array (withArray)
import Foreign.Ptr           (Ptr, castPtr, plusPtr)
import Foreign.Storable      (Storable (..))

-- | Convert a list of prime powers to a suitable C representation.
--
-- Each component of every prime power is narrowed to 'Int16' (see 'CPP').
-- A component outside the 'Int16' range is rejected with an error instead
-- of being silently wrapped modulo 2^16, which would hand the C++ backend a
-- different factorization than the one requested.
marshalFactors :: [PP] -> Vector CPP
marshalFactors = SV.fromList . LP.map (toInt16 *** toInt16)
  where
    toInt16 :: Integral i => i -> Int16
    toInt16 x
      | n < toInteger (minBound :: Int16) || n > toInteger (maxBound :: Int16) =
          error ("Crypto.Lol.Cyclotomic.Tensor.CPP.Backend.marshalFactors: "
                 ++ "prime-power component " ++ show n
                 ++ " does not fit in Int16")
      | otherwise = fromIntegral x
      where
        n = toInteger x

-- http://stackoverflow.com/questions/6517387/vector-vector-foo-ptr-ptr-foo-io-a-io-a
-- | Evaluates a C function that takes an "a** ptr" on a list of Vectors.
--
-- Each vector's buffer is pinned with 'withForeignPtr' for the whole dynamic
-- extent of the callback. None of the buffers can therefore be finalized
-- while @f@ (and any C code it calls) still uses the pointers, even if @f@
-- throws. A trailing 'touchForeignPtr' gives no such guarantee when the
-- preceding code may diverge. The pointer array handed to @f@ lists the
-- buffers in the same order as the input list, and it is only valid during
-- the callback.
withPtrArray :: (Storable a) => [Vector a] -> (Ptr (Ptr a) -> IO b) -> IO b
withPtrArray v f = go v []
  where
    -- Pin the buffers one by one, collecting their raw pointers in reverse;
    -- once all are pinned, pass the (re-ordered) pointer array to f.
    go [] ptrs = withArray (reverse ptrs) f
    go (vec:rest) ptrs =
      let (fp, _) = SV.unsafeToForeignPtr0 vec
      in withForeignPtr fp $ \p -> go rest (p : ptrs)

-- | C representation of a prime power: a pair of 16-bit integers, as built
-- by 'marshalFactors'.
--
-- Both components deliberately have the same type. The 'Storable' instance
-- for pairs in this module has to reproduce the byte layout of the matching
-- C struct, and with two identical field types no padding is involved that
-- the two sides could disagree about.
type CPP =
  ( Int16   -- presumably the prime (TODO confirm against 'PP')
  , Int16   -- presumably the exponent
  )

-- | Marshals a pair with the layout a C compiler gives to
-- @struct { a first; b second; }@. The second component starts at the first
-- offset after the first component that satisfies @b@'s alignment. The total
-- size is padded up to the pair's alignment, so consecutive elements of an
-- array stay aligned, because array marshalling steps by 'sizeOf'.
--
-- When both components have the same type (as for 'CPP') no padding is
-- needed, and the layout coincides with a tightly packed pair.
--
-- NOTE(review): this is an orphan instance on a base type; a dedicated
-- newtype would avoid clashing with any other @Storable (a,b)@ instance.
instance (Storable a, Storable b)
  => Storable (a,b) where
  sizeOf _ = roundUpTo (max (alignment x) (alignment y)) (pairSndOffset x y + sizeOf y)
    where
      x = undefined :: a
      y = undefined :: b
  alignment _ = max (alignment (undefined :: a)) (alignment (undefined :: b))
  peek p = do
    a <- peek (castPtr p :: Ptr a)
    b <- peek (plusPtr p (pairSndOffset a (undefined :: b)) :: Ptr b)
    return (a,b)
  poke p (a,b) = do
    poke (castPtr p :: Ptr a) a
    poke (plusPtr p (pairSndOffset a b) :: Ptr b) b

-- | Byte offset of the second component of a C-compatible pair: the size of
-- the first component rounded up to the alignment of the second. Only the
-- argument types matter; neither argument is evaluated.
pairSndOffset :: (Storable a, Storable b) => a -> b -> Int
pairSndOffset x y = roundUpTo (alignment y) (sizeOf x)

-- | @roundUpTo k n@ is the smallest multiple of @k@ that is at least @n@.
-- Assumes @k > 0@, which holds for every 'alignment'.
roundUpTo :: Int -> Int -> Int
roundUpTo k n = ((n + k - 1) `div` k) * k

-- | Runs the C++ routine @tensorCRTRq@ on the buffer @pout@. The pointer
-- array @ruptr@ comes first on the Haskell side but is passed after
-- @numFacts@ to C++. The 'Int64' reflected from @q@ (presumably the modulus)
-- is appended as the last argument.
dcrtZq :: forall q . Reflects q Int64 => Ptr (Ptr (ZqBasic q Int64)) -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dcrtZq ruptr pout totm pfac numFacts =
  tensorCRTRq pout totm pfac numFacts ruptr qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorCRTInvRq@ on the buffer @pout@. The Haskell
-- argument order puts @ruptr@ and @minv@ first, while C++ receives them
-- after @numFacts@, followed by the 'Int64' reflected from @q@.
dcrtinvZq :: forall q . Reflects q Int64 => Ptr (Ptr (ZqBasic q Int64)) -> Ptr (ZqBasic q Int64) ->  Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dcrtinvZq ruptr minv pout totm pfac numFacts =
  tensorCRTInvRq pout totm pfac numFacts ruptr minv qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorLRq@ on @pout@. The arguments are forwarded
-- unchanged, with the 'Int64' reflected from @q@ appended.
dlZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlZq pout totm pfac numFacts = tensorLRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorLInvRq@ on @pout@. The arguments are
-- forwarded unchanged, with the 'Int64' reflected from @q@ appended.
dlinvZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlinvZq pout totm pfac numFacts = tensorLInvRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorGPowRq@ on @pout@. The arguments are
-- forwarded unchanged, with the 'Int64' reflected from @q@ appended.
dmulgpowZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgpowZq pout totm pfac numFacts = tensorGPowRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorGDecRq@ on @pout@. The arguments are
-- forwarded unchanged, with the 'Int64' reflected from @q@ appended.
dmulgdecZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgdecZq pout totm pfac numFacts = tensorGDecRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorGInvPowRq@ on @pout@, with the 'Int64'
-- reflected from @q@ appended. The routine's 'Int16' result is returned
-- as-is; its meaning is defined on the C++ side.
dginvpowZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvpowZq pout totm pfac numFacts = tensorGInvPowRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorGInvDecRq@ on @pout@, with the 'Int64'
-- reflected from @q@ appended. The routine's 'Int16' result is returned
-- as-is; its meaning is defined on the C++ side.
dginvdecZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvdecZq pout totm pfac numFacts = tensorGInvDecRq pout totm pfac numFacts qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @mulRq@ on the two buffers @aout@ and @bout@. The
-- arguments are forwarded unchanged, with the 'Int64' reflected from @q@
-- appended.
dmulZq :: forall q . Reflects q Int64 => Ptr (ZqBasic q Int64) -> Ptr (ZqBasic q Int64) -> Int64 -> IO ()
dmulZq aout bout totm = mulRq aout bout totm qVal
  where
    qVal :: Int64
    qVal = value @q

-- | Runs the C++ routine @tensorCRTC@ on @pout@. The pointer array @ruptr@
-- comes first on the Haskell side and is moved to the end for C++.
dcrtC :: Ptr (Ptr (Complex Double)) -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dcrtC ruptr pout totm pfac numFacts = tensorCRTC pout totm pfac numFacts ruptr

-- | Runs the C++ routine @tensorCRTInvC@ on @pout@. @ruptr@ and @minv@ come
-- first on the Haskell side and are moved after @numFacts@ for C++.
dcrtinvC :: Ptr (Ptr (Complex Double)) -> Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dcrtinvC ruptr minv pout totm pfac numFacts =
  tensorCRTInvC pout totm pfac numFacts ruptr minv

-- | Haskell entry point for the C++ routine @tensorLC@; all arguments are
-- forwarded unchanged.
dlC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlC = tensorLC

-- | Haskell entry point for the C++ routine @tensorLInvC@; all arguments
-- are forwarded unchanged.
dlinvC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlinvC = tensorLInvC

-- | Haskell entry point for the C++ routine @tensorGPowC@; all arguments
-- are forwarded unchanged.
dmulgpowC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgpowC = tensorGPowC

-- | Haskell entry point for the C++ routine @tensorGDecC@; all arguments
-- are forwarded unchanged.
dmulgdecC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgdecC = tensorGDecC

-- | Haskell entry point for the C++ routine @tensorGInvPowC@. The routine's
-- 'Int16' result is returned as-is; its meaning is defined on the C++ side.
dginvpowC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvpowC = tensorGInvPowC

-- | Haskell entry point for the C++ routine @tensorGInvDecC@. The routine's
-- 'Int16' result is returned as-is; its meaning is defined on the C++ side.
dginvdecC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvdecC = tensorGInvDecC

-- | Haskell entry point for the C++ routine @mulC@ on two complex buffers;
-- all arguments are forwarded unchanged.
dmulC :: Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> IO ()
dmulC = mulC

-- | Haskell entry point for the C++ routine @tensorLRRq@. The type-level @q@
-- only tags the Haskell element type, and no value derived from it is passed
-- to C++, so any @q@ is accepted.
dlRRq :: forall q . Ptr (RRq q Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlRRq = tensorLRRq

-- | Haskell entry point for the C++ routine @tensorLInvRRq@. The type-level
-- @q@ only tags the Haskell element type, and no value derived from it is
-- passed to C++, so any @q@ is accepted.
dlinvRRq :: forall q . Ptr (RRq q Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlinvRRq = tensorLInvRRq

-- | Haskell entry point for the C++ routine @tensorLDouble@; all arguments
-- are forwarded unchanged.
dlDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlDouble = tensorLDouble

-- | Haskell entry point for the C++ routine @tensorLInvDouble@; all
-- arguments are forwarded unchanged.
dlinvDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlinvDouble = tensorLInvDouble

-- | Haskell entry point for the C++ routine @tensorNormSqD@; all arguments
-- are forwarded unchanged.
dnormDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dnormDouble = tensorNormSqD

-- | Haskell entry point for the C++ routine @tensorGPowDouble@; all
-- arguments are forwarded unchanged.
dmulgpowDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgpowDouble = tensorGPowDouble

-- | Haskell entry point for the C++ routine @tensorGDecDouble@; all
-- arguments are forwarded unchanged.
dmulgdecDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgdecDouble = tensorGDecDouble

-- | Haskell entry point for the C++ routine @tensorGInvPowDouble@. The
-- routine's 'Int16' result is returned as-is; its meaning is defined on the
-- C++ side.
dginvpowDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvpowDouble = tensorGInvPowDouble

-- | Haskell entry point for the C++ routine @tensorGInvDecDouble@. The
-- routine's 'Int16' result is returned as-is; its meaning is defined on the
-- C++ side.
dginvdecDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvdecDouble = tensorGInvDecDouble

-- | Runs the C++ routine @tensorGaussianDec@ on the real buffer @pout@. The
-- complex pointer array @ruptr@ comes first on the Haskell side and is moved
-- to the end for C++.
dgaussdecDouble :: Ptr (Ptr (Complex Double)) -> Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
dgaussdecDouble ruptr pout totm pfac numFacts =
  tensorGaussianDec pout totm pfac numFacts ruptr

-- | Haskell entry point for the C++ routine @tensorLR@; all arguments are
-- forwarded unchanged.
dlInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlInt64 = tensorLR

-- | Haskell entry point for the C++ routine @tensorLInvR@; all arguments
-- are forwarded unchanged.
dlinvInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
dlinvInt64 = tensorLInvR

-- | Haskell entry point for the C++ routine @tensorNormSqR@; all arguments
-- are forwarded unchanged.
dnormInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
dnormInt64 = tensorNormSqR

-- | Haskell entry point for the C++ routine @tensorGPowR@; all arguments
-- are forwarded unchanged.
dmulgpowInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgpowInt64 = tensorGPowR

-- | Haskell entry point for the C++ routine @tensorGDecR@; all arguments
-- are forwarded unchanged.
dmulgdecInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
dmulgdecInt64 = tensorGDecR

-- | Haskell entry point for the C++ routine @tensorGInvPowR@. The routine's
-- 'Int16' result is returned as-is; its meaning is defined on the C++ side.
dginvpowInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvpowInt64 = tensorGInvPowR

-- | Haskell entry point for the C++ routine @tensorGInvDecR@. The routine's
-- 'Int16' result is returned as-is; its meaning is defined on the C++ side.
dginvdecInt64 :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
dginvdecInt64 = tensorGInvDecR

-- Raw FFI bindings to the C++ tensor backend.
--
-- Every import is declared @unsafe@. The C routines must therefore never
-- call back into Haskell, and while a call runs it holds its capability, so
-- other Haskell threads (e.g. ones waiting for a GC) may stall until it
-- returns. These Haskell types must agree exactly with the C prototypes,
-- which are not visible in this module.
--
-- Name suffixes correspond to element types: R = Int64,
-- Rq = ZqBasic q Int64, Double = Double, RRq = RRq q Double,
-- C = Complex Double. The extra trailing Int64 taken by the Rq routines is
-- filled with @value \@q@ by the Zq wrappers above.
--
-- L-family routines.
foreign import ccall unsafe "tensorLR" tensorLR ::                  Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvR" tensorLInvR ::            Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLRq" tensorLRq ::                Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorLInvRq" tensorLInvRq ::          Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorLDouble" tensorLDouble ::       Ptr Double -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvDouble" tensorLInvDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16          -> IO ()
-- The q in the RRq imports is unconstrained; no value derived from it is
-- passed to C++.
foreign import ccall unsafe "tensorLRRq" tensorLRRq ::       Ptr (RRq q Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvRRq" tensorLInvRRq :: Ptr (RRq q Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLC" tensorLC ::       Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvC" tensorLInvC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()

-- Squared-norm routines (integer and Double variants).
foreign import ccall unsafe "tensorNormSqR" tensorNormSqR :: Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorNormSqD" tensorNormSqD :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()

-- G-family routines. The GInv* variants return an Int16 that the dginv*
-- wrappers pass back unchanged; its meaning is defined on the C++ side
-- (presumably a success/divisibility flag -- TODO confirm).
foreign import ccall unsafe "tensorGPowR" tensorGPowR ::                 Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGPowRq" tensorGPowRq ::               Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorGPowDouble" tensorGPowDouble ::       Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGPowC" tensorGPowC ::                 Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecR" tensorGDecR ::                 Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecRq" tensorGDecRq ::               Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorGDecDouble" tensorGDecDouble ::       Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecC" tensorGDecC ::                 Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR ::           Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq ::         Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvPowDouble" tensorGInvPowDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC ::           Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR ::           Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq ::         Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvDecDouble" tensorGInvDecDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC ::           Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16

-- CRT routines. The pointer-array argument (and, for the inverses, the extra
-- buffer pointer) come after the factor count in the C argument order.
foreign import ccall unsafe "tensorCRTRq" tensorCRTRq ::         Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Int64 -> IO ()
foreign import ccall unsafe "tensorCRTC" tensorCRTC ::           Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
foreign import ccall unsafe "tensorCRTInvRq" tensorCRTInvRq ::   Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Ptr (ZqBasic q Int64) -> Int64 -> IO ()
foreign import ccall unsafe "tensorCRTInvC" tensorCRTInvC ::     Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> Ptr (Complex Double) -> IO ()

-- Takes a real (Double) buffer together with an array of complex pointers.
foreign import ccall unsafe "tensorGaussianDec" tensorGaussianDec :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) ->  IO ()

-- Routines over two buffers of the same element type; the Rq variant also
-- takes the trailing Int64 supplied by dmulZq.
foreign import ccall unsafe "mulRq" mulRq :: Ptr (ZqBasic q Int64) -> Ptr (ZqBasic q Int64) -> Int64 -> Int64 -> IO ()
foreign import ccall unsafe "mulC" mulC :: Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> IO ()