{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Math.Lattices.Fplll.Internal where

import Control.Exception
import Data.Foldable
import Data.Maybe
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable
import Numeric.GMP.Raw.Unsafe
import Numeric.GMP.Types
import Numeric.GMP.Utils

-- | Selector for the LLL algorithm variant. It is passed by value to the C
-- reduction functions imported at the bottom of this module. It wraps a raw
-- 'CInt'. Its 'Storable' instance is derived through
-- GeneralizedNewtypeDeriving, so it is marshalled exactly like a 'CInt'.
-- Valid values are presumably the ones read through the @c_lm*@ address
-- imports; NOTE(review): confirm.
newtype LLLMethod = LLLMethod CInt deriving (Eq, Ord, Storable)
-- | Selector for the floating-point type used during reduction. It is passed
-- by value to the C reduction functions and wraps a raw 'CInt'. Valid values
-- are presumably the ones read through the @c_ft*@ address imports;
-- NOTE(review): confirm.
newtype FloatType = FloatType CInt deriving (Eq, Ord, Storable)
-- | Status code returned by the C reduction functions. It wraps a raw
-- 'CInt', and presumably it is compared against the values read through the
-- @c_red*@ address imports; NOTE(review): confirm.
newtype RedStatus = RedStatus CInt deriving (Eq, Ord, Storable)

-- | Flags controlling LLL reduction. Can be combined using 'Algebra.Lattice.\/'.
-- Passed by value (as a raw 'CInt') to the C reduction functions.
newtype LLLFlags  = LLLFlags  CInt deriving (Eq, Ord, Storable)


-- | Temporarily allocate a contiguous array of @len@ GMP integers and run the
-- continuation on a pointer to its first element.
--
-- Every slot is initialised with @mpz_init@ before the continuation runs.
-- Every slot is released with @mpz_clear@ afterwards, even if the
-- continuation throws (this is guaranteed by 'bracket_'). The array storage
-- itself is freed when 'allocaArray' returns, so the pointer must not escape
-- the continuation.
allocaMpz :: Int -> (Ptr MPZ -> IO a) -> IO a
allocaMpz len f = allocaArray len $ \arr ->
  let forEachSlot act = mapM_ (act . advancePtr arr) [0 .. len - 1]
   in bracket_ (forEachSlot mpz_init) (forEachSlot mpz_clear) (f arr)

-- | Read @vecs@ vectors of @len@ coordinates each out of a row-major array of
-- GMP integers. Coordinate @col@ of vector @row@ is read from offset
-- @row * len + col@. Rows are read in order, and within a row the
-- coordinates are read left to right.
peekBasis :: Int -> Int -> Ptr MPZ -> IO [[Integer]]
peekBasis vecs len bArr = traverse peekRow [0 .. vecs - 1]
  where
    peekRow row = traverse (peekCell row) [0 .. len - 1]
    peekCell row col = peekInteger (bArr `advancePtr` (row * len + col))

-- | Write a basis into a row-major array of GMP integers. Coordinate @col@ of
-- vector @row@ goes to offset @row * len + col@.
--
-- At most @vecs@ vectors, and at most @len@ coordinates of each vector, are
-- written. Any excess is silently ignored, and shorter input leaves the
-- remaining slots untouched.
pokeBasis :: Int -> Int -> Ptr MPZ -> [[Integer]] -> IO ()
pokeBasis vecs len bArr b =
  sequence_
    [ pokeInteger (bArr `advancePtr` (row * len + col)) x
    | (row, vec) <- zip [0 .. vecs - 1] b
    , (col, x) <- zip [0 .. len - 1] vec
    ]

-- | Marshal a basis (a list of integer vectors) into a freshly allocated,
-- initialised row-major GMP array. The continuation receives the number of
-- vectors, the common vector length and the array. The array entries are
-- cleared and its storage released once the continuation returns or throws
-- (see 'allocaMpz').
--
-- An empty basis is accepted and yields @0@ vectors of length @0@. Calls
-- 'error' if the vectors do not all have the same length.
allocaAndPokeBasis :: [[Integer]] -> (Int -> Int -> Ptr MPZ -> IO a) -> IO a
allocaAndPokeBasis b f
  | allEqual rowLens =
    allocaMpz (vecs * len) $ \bArr -> do
      -- Delegate to the shared writer so the row-major index layout is
      -- defined in one place, shared with 'peekBasis'.
      pokeBasis vecs len bArr b
      f vecs len bArr
  | otherwise = error "Basis vectors aren't all the same length"
  where
    rowLens = map length b
    vecs = length b
    len = fromMaybe 0 (listToMaybe rowLens)

-- | 'True' when every element of the list equals the first one. The result is
-- trivially 'True' for empty and singleton lists.
--
-- A total pattern match replaces the original use of partial 'head'. The old
-- code was safe only because 'all' never forced @head []@ on an empty tail.
allEqual :: (Eq a) => [a] -> Bool
allEqual [] = True
allEqual (x : xs) = all (== x) xs


-- Addresses of the C globals @lllDefaultDelta@ and @lllDefaultEta@. The @&@
-- in the import gives the symbol's address, so no value is read until the
-- pointer is 'peek'ed. These are presumably fplll's default delta and eta LLL
-- parameters, exported by the C shim; NOTE(review): confirm.
foreign import ccall "&lllDefaultDelta" c_lllDefaultDelta :: Ptr CDouble
foreign import ccall "&lllDefaultEta" c_lllDefaultEta :: Ptr CDouble

-- Addresses of C globals holding the integer code of each LLL method
-- (wrapper, proved, heuristic, fast). The codes are presumably meant to be
-- wrapped in 'LLLMethod'; NOTE(review): confirm against the C shim.
foreign import ccall "&lmWrapper" c_lmWrapper :: Ptr CInt
foreign import ccall "&lmProved" c_lmProved :: Ptr CInt
foreign import ccall "&lmHeuristic" c_lmHeuristic :: Ptr CInt
foreign import ccall "&lmFast" c_lmFast :: Ptr CInt

-- Address of the C symbol @lllMethodStr@. The declared type treats it as a
-- variable holding a pointer to C strings (a @char **@ global). It is
-- presumably a table of method names indexed by the codes above.
-- NOTE(review): if the C side defines it as an array of @char *@ rather than
-- a pointer, the Haskell type should be @Ptr CString@; confirm.
foreign import ccall "&lllMethodStr" c_lllMethodStr :: Ptr (Ptr CString)

-- Addresses of C globals holding the integer values of the LLL flags
-- (verbose, early reduction, Siegel condition, default). They are presumably
-- meant to be wrapped in 'LLLFlags'; NOTE(review): confirm against the C shim.
foreign import ccall "&lllVerbose" c_lllVerbose :: Ptr CInt
foreign import ccall "&lllEarlyRed" c_lllEarlyRed :: Ptr CInt
foreign import ccall "&lllSiegel" c_lllSiegel :: Ptr CInt
foreign import ccall "&lllDefault" c_lllDefault :: Ptr CInt

-- Addresses of C globals holding the integer code of each floating-point
-- type usable during reduction (default, double, long double, dpe, dd, qd,
-- mpfr). The codes are presumably meant to be wrapped in 'FloatType';
-- NOTE(review): confirm against the C shim.
foreign import ccall "&ftDefault" c_ftDefault :: Ptr CInt
foreign import ccall "&ftDouble" c_ftDouble :: Ptr CInt
foreign import ccall "&ftLongDouble" c_ftLongDouble :: Ptr CInt
foreign import ccall "&ftDpe" c_ftDpe :: Ptr CInt
foreign import ccall "&ftDD" c_ftDD :: Ptr CInt
foreign import ccall "&ftQD" c_ftQD :: Ptr CInt
foreign import ccall "&ftMpfr" c_ftMpfr :: Ptr CInt

-- Address of the C symbol @floatTypeStr@, typed as a @char **@ variable. It is
-- presumably a table of float-type names indexed by the codes above.
-- NOTE(review): same array-vs-pointer caveat as @lllMethodStr@; confirm.
foreign import ccall "&floatTypeStr" c_floatTypeStr :: Ptr (Ptr CString)

-- Addresses of C globals holding the integer value of each reduction status
-- (success and the various GSO/Babai/LLL/enum/BKZ/HLLL failure and limit
-- codes). The values are presumably compared against the 'RedStatus'
-- returned by the reduction calls; NOTE(review): confirm against the C shim.
foreign import ccall "&redSuccess" c_redSuccess :: Ptr CInt
foreign import ccall "&redGsoFailure" c_redGsoFailure :: Ptr CInt
foreign import ccall "&redBabaiFailure" c_redBabaiFailure :: Ptr CInt
foreign import ccall "&redLllFailure" c_redLllFailure :: Ptr CInt
foreign import ccall "&redEnumFailure" c_redEnumFailure :: Ptr CInt
foreign import ccall "&redBkzFailure" c_redBkzFailure :: Ptr CInt
foreign import ccall "&redBkzTimeLimit" c_redBkzTimeLimit :: Ptr CInt
foreign import ccall "&redBkzLoopsLimit" c_redBkzLoopsLimit :: Ptr CInt
foreign import ccall "&redHlllFailure" c_redHlllFailure :: Ptr CInt
foreign import ccall "&redHlllNormFailure" c_redHlllNormFailure :: Ptr CInt
foreign import ccall "&redHlllSrFailure" c_redHlllSrFailure :: Ptr CInt

-- Address of the C symbol @redStatusStr@, typed as a @char **@ variable. It is
-- presumably a table of status descriptions indexed by the values above.
-- NOTE(review): same array-vs-pointer caveat as @lllMethodStr@; confirm.
foreign import ccall "&redStatusStr" c_redStatusStr :: Ptr (Ptr CString)

-- Bindings to the LLL reduction entry points of the C shim. Each takes:
--
-- * two 'CInt' dimensions,
-- * one, two or three arrays of GMP integers,
-- * two 'CDouble' parameters,
-- * an 'LLLMethod', a 'FloatType', a further 'CInt' and the 'LLLFlags'.
--
-- Each returns a 'RedStatus'. The argument meanings are presumably (rows,
-- columns, basis [, U [, U^-1]], delta, eta, method, float type, precision,
-- flags); NOTE(review): confirm against the C shim.
--
-- These are imported @safe@ rather than @unsafe@. An LLL reduction can run
-- for a long time, and an @unsafe@ call blocks the garbage collector until it
-- returns. Once a collection is requested, that stalls every other Haskell
-- thread. The 'Ptr' arguments refer to memory the GC never moves (e.g. the
-- pinned buffers from 'allocaArray'), so allowing GC during the call is sound.
foreign import ccall safe "hs_ffi_lll_reduction"
  c_lll_reduction :: CInt -> CInt -> Ptr MPZ -> CDouble -> CDouble -> LLLMethod -> FloatType -> CInt
                     -> LLLFlags -> IO RedStatus
foreign import ccall safe "hs_ffi_lll_reduction_u_id"
  c_lll_reduction_u_id :: CInt -> CInt -> Ptr MPZ -> Ptr MPZ -> CDouble -> CDouble -> LLLMethod
                       -> FloatType -> CInt -> LLLFlags -> IO RedStatus
foreign import ccall safe "hs_ffi_lll_reduction_uinv_id"
  c_lll_reduction_uinv_id :: CInt -> CInt -> Ptr MPZ -> Ptr MPZ -> Ptr MPZ -> CDouble -> CDouble
                          -> LLLMethod -> FloatType -> CInt -> LLLFlags -> IO RedStatus