{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | Marshalling helpers for arrays of GMP integers together with the raw
-- foreign imports used by the fplll bindings.
module Math.Lattices.Fplll.Internal where

import Control.Exception
import Data.Foldable
import Data.Maybe
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable
import Numeric.GMP.Raw.Unsafe
import Numeric.GMP.Types
import Numeric.GMP.Utils

-- | Wrapper around the C @int@ passed as the method argument of the
-- reduction entry points.
newtype LLLMethod = LLLMethod CInt deriving (Eq, Ord, Storable)

-- | Wrapper around the C @int@ passed as the float type argument of the
-- reduction entry points.
newtype FloatType = FloatType CInt deriving (Eq, Ord, Storable)

-- | Wrapper around the C @int@ returned by the reduction entry points.
newtype RedStatus = RedStatus CInt deriving (Eq, Ord, Storable)

-- | Flags controlling LLL reduction. Can be combined using 'Algebra.Lattice.\/'.
newtype LLLFlags = LLLFlags CInt deriving (Eq, Ord, Storable)

-- | Allocate a temporary array of @len@ 'MPZ' cells, run @mpz_init@ on every
-- cell, pass the array to the action and run @mpz_clear@ on every cell
-- afterwards. Because 'bracket_' is used, the clearing step also runs when
-- the action throws.
allocaMpz :: Int -> (Ptr MPZ -> IO a) -> IO a
allocaMpz len f =
  allocaArray len $ \p ->
    bracket_
      (traverse_ (mpz_init . advancePtr p) [0 .. len - 1])
      (traverse_ (mpz_clear . advancePtr p) [0 .. len - 1])
      (f p)

-- | Read a @vecs@ by @len@ matrix out of a flat 'MPZ' array. Entry @j@ of
-- vector @i@ is taken from offset @i * len + j@.
peekBasis :: Int -> Int -> Ptr MPZ -> IO [[Integer]]
peekBasis vecs len bArr = traverse readRow [0 .. vecs - 1]
  where
    readRow i = traverse (readVal i) [0 .. len - 1]
    readVal i j = peekInteger (advancePtr bArr (i * len + j))

-- | Write the given vectors into a flat 'MPZ' array. Entry @j@ of vector @i@
-- is stored at offset @i * len + j@. At most @vecs@ vectors and at most
-- @len@ entries of each vector are written; anything beyond that is ignored.
pokeBasis :: Int -> Int -> Ptr MPZ -> [[Integer]] -> IO ()
pokeBasis vecs len bArr b =
  for_ (zip [0 .. vecs - 1] b) $ \(i, vec) ->
    for_ (zip [0 .. len - 1] vec) $ \(j, x) ->
      pokeInteger (advancePtr bArr (i * len + j)) x

-- | Allocate an initialised 'MPZ' array large enough for the basis, store the
-- basis in it and run the continuation with the number of vectors, the length
-- of each vector and the array.
--
-- Calls 'error' when the vectors do not all have the same length. An empty
-- basis is accepted and yields zero vectors of length zero.
allocaAndPokeBasis :: [[Integer]] -> (Int -> Int -> Ptr MPZ -> IO a) -> IO a
allocaAndPokeBasis b f
  | allEqual (length <$> b) =
      allocaMpz (vecs * len) $ \bArr -> do
        -- Every row has exactly len entries here, so pokeBasis fills the
        -- whole array.
        pokeBasis vecs len bArr b
        f vecs len bArr
  | otherwise = error "Basis vectors aren't all the same length"
  where
    vecs = length b
    len = maybe 0 length (listToMaybe b)

-- | 'True' when every element of the list equals the first one. The empty
-- list and singleton lists yield 'True'.
allEqual :: (Eq a) => [a] -> Bool
allEqual [] = True
allEqual (x : xs) = all (== x) xs

-- Addresses of variables exported by the C side. Each import below yields a
-- pointer to the named symbol; the value has to be read with 'peek'.
-- NOTE(review): the meaning of each value is defined by the C wrapper, which
-- is not part of this module - confirm there.

-- Default reduction parameters.
foreign import ccall "&lllDefaultDelta" c_lllDefaultDelta :: Ptr CDouble
foreign import ccall "&lllDefaultEta" c_lllDefaultEta :: Ptr CDouble

-- Method constants (presumably the values wrapped by 'LLLMethod').
foreign import ccall "&lmWrapper" c_lmWrapper :: Ptr CInt
foreign import ccall "&lmProved" c_lmProved :: Ptr CInt
foreign import ccall "&lmHeuristic" c_lmHeuristic :: Ptr CInt
foreign import ccall "&lmFast" c_lmFast :: Ptr CInt
foreign import ccall "&lllMethodStr" c_lllMethodStr :: Ptr (Ptr CString)

-- Flag constants (presumably the values wrapped by 'LLLFlags').
foreign import ccall "&lllVerbose" c_lllVerbose :: Ptr CInt
foreign import ccall "&lllEarlyRed" c_lllEarlyRed :: Ptr CInt
foreign import ccall "&lllSiegel" c_lllSiegel :: Ptr CInt
foreign import ccall "&lllDefault" c_lllDefault :: Ptr CInt

-- Float type constants (presumably the values wrapped by 'FloatType').
foreign import ccall "&ftDefault" c_ftDefault :: Ptr CInt
foreign import ccall "&ftDouble" c_ftDouble :: Ptr CInt
foreign import ccall "&ftLongDouble" c_ftLongDouble :: Ptr CInt
foreign import ccall "&ftDpe" c_ftDpe :: Ptr CInt
foreign import ccall "&ftDD" c_ftDD :: Ptr CInt
foreign import ccall "&ftQD" c_ftQD :: Ptr CInt
foreign import ccall "&ftMpfr" c_ftMpfr :: Ptr CInt
foreign import ccall "&floatTypeStr" c_floatTypeStr :: Ptr (Ptr CString)

-- Status constants (presumably the values wrapped by 'RedStatus').
foreign import ccall "&redSuccess" c_redSuccess :: Ptr CInt
foreign import ccall "&redGsoFailure" c_redGsoFailure :: Ptr CInt
foreign import ccall "&redBabaiFailure" c_redBabaiFailure :: Ptr CInt
foreign import ccall "&redLllFailure" c_redLllFailure :: Ptr CInt
foreign import ccall "&redEnumFailure" c_redEnumFailure :: Ptr CInt
foreign import ccall "&redBkzFailure" c_redBkzFailure :: Ptr CInt
foreign import ccall "&redBkzTimeLimit" c_redBkzTimeLimit :: Ptr CInt
foreign import ccall "&redBkzLoopsLimit" c_redBkzLoopsLimit :: Ptr CInt
foreign import ccall "&redHlllFailure" c_redHlllFailure :: Ptr CInt
foreign import ccall "&redHlllNormFailure" c_redHlllNormFailure :: Ptr CInt
foreign import ccall "&redHlllSrFailure" c_redHlllSrFailure :: Ptr CInt
foreign import ccall "&redStatusStr" c_redStatusStr :: Ptr (Ptr CString)

-- Reduction entry points.
-- NOTE(review): the argument order is presumably: number of vectors, vector
-- length, basis array (plus one or two further matrices for the u_id and
-- uinv_id variants), delta, eta, method, float type, precision, flags -
-- confirm against the C wrapper.
-- NOTE(review): these are imported as unsafe calls. If a reduction can run
-- for a long time, consider whether a safe import would be more appropriate.

foreign import ccall unsafe "hs_ffi_lll_reduction"
  c_lll_reduction
    :: CInt
    -> CInt
    -> Ptr MPZ
    -> CDouble
    -> CDouble
    -> LLLMethod
    -> FloatType
    -> CInt
    -> LLLFlags
    -> IO RedStatus

foreign import ccall unsafe "hs_ffi_lll_reduction_u_id"
  c_lll_reduction_u_id
    :: CInt
    -> CInt
    -> Ptr MPZ
    -> Ptr MPZ
    -> CDouble
    -> CDouble
    -> LLLMethod
    -> FloatType
    -> CInt
    -> LLLFlags
    -> IO RedStatus

foreign import ccall unsafe "hs_ffi_lll_reduction_uinv_id"
  c_lll_reduction_uinv_id
    :: CInt
    -> CInt
    -> Ptr MPZ
    -> Ptr MPZ
    -> Ptr MPZ
    -> CDouble
    -> CDouble
    -> LLLMethod
    -> FloatType
    -> CInt
    -> LLLFlags
    -> IO RedStatus