{-|
Module      : Crypto.Secp256k1.Internal
Description : Internal SECP256K1 cryptographic functions
License     : PublicDomain
Maintainer  : root@haskoin.com
Stability   : experimental
Portability : POSIX

The API in this module may change at any time.  This is an internal module
only exposed for hacking and experimentation.
-}
module Crypto.Secp256k1.Internal where

import           Control.Monad

import           Data.ByteString        (ByteString, packCStringLen)
import           Data.ByteString.Unsafe (unsafeUseAsCStringLen)

import           Foreign
import           Foreign.C

import           System.Entropy
import           System.IO.Unsafe

-- | Tag type standing for the opaque C @secp256k1_context@.  In this module
-- it only appears as the target of a pointer ('Ptr' 'Ctx'); the 'Ctx' value
-- constructor carries no data.
data Ctx = Ctx

-- | Public key in the 64-byte form exchanged with libsecp256k1 (its
-- 'Storable' 'sizeOf' is 64).  Used via 'Ptr' 'PubKey64' by the
-- @ec_pubkey_*@ bindings and 'ecdsa_verify'.
newtype PubKey64 = PubKey64 { getPubKey64 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 32-byte message passed to 'ecdsa_sign' and 'ecdsa_verify'.
newtype Msg32 = Msg32 { getMsg32 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | ECDSA signature in the 64-byte form exchanged with libsecp256k1;
-- 'ecdsa_signature_parse_der' and 'ecdsa_signature_serialize_der' convert
-- between it and DER.
newtype Sig64 = Sig64 { getSig64 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 32 bytes of seed material for 'context_randomize' ('ctx' fills it from
-- 'getEntropy').
newtype Seed32 = Seed32 { getSeed32 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 32-byte secret key.
--
-- NOTE(review): the derived 'Show'/'Read' instances expose the raw key
-- bytes; avoid showing values of this type in logs.
newtype SecKey32 = SecKey32 { getSecKey32 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 32-byte tweak for the @ec_privkey_tweak_*@ / @ec_pubkey_tweak_*@
-- bindings.
newtype Tweak32 = Tweak32 { getTweak32 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 32-byte nonce; the first argument of a 'NonceFunction'.
newtype Nonce32 = Nonce32 { getNonce32 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | 16-byte value passed as the fourth argument of a 'NonceFunction'
-- (presumably an algorithm identifier, going by the name; confirm in
-- secp256k1.h).
newtype Algo16 = Algo16 { getAlgo16 :: ByteString }
    deriving (Read, Show, Eq, Ord)

-- | Flags given to 'context_create'; see 'verify', 'sign' and 'signVerify'.
-- Like 'SerFlags' and 'Ret', this is a plain newtype over a C integer type
-- so it can appear directly in the foreign import signatures.
newtype CtxFlags = CtxFlags { getCtxFlags :: CUInt }
    deriving (Read, Show, Eq, Ord)

-- | Serialization flags for 'ec_pubkey_serialize' and 'ec_privkey_export';
-- see 'compressed' and 'uncompressed'.
newtype SerFlags = SerFlags { getSerFlags :: CUInt }
    deriving (Read, Show, Eq, Ord)

-- | Raw @int@ status returned by the C bindings; decode it with 'isSuccess'
-- (0 = failure, 1 = success).
newtype Ret = Ret { getRet :: CInt }
    deriving (Read, Show, Eq, Ord)

-- | Nonce32-generating function: the C callback type whose 'FunPtr' is
-- passed to 'ecdsa_sign'.  The 'Ptr' 'Nonce32' is presumably the output
-- buffer for the generated nonce (confirm in secp256k1.h); the 'Ret' result
-- is the callback's status code.
--
-- NOTE(review): the result is not in 'IO', so this type only suits
-- pointers to C functions; a Haskell callback built with a @wrapper@ import
-- would need an 'IO' result type.
type NonceFunction a
    =  Ptr Nonce32
    -> Ptr Msg32
    -> Ptr SecKey32
    -> Ptr Algo16
    -> Ptr a       -- ^ extra data
    -> CUInt       -- ^ attempt
    -> Ret

-- | Context flag requesting verification support (bit 0, value 1).
--
-- NOTE(review): these numeric values must match the @SECP256K1_CONTEXT_*@
-- macros of the bundled secp256k1.h; later upstream releases changed them.
verify :: CtxFlags
verify = CtxFlags (bit 0)

-- | Context flag requesting signing support (bit 1, value 2).
sign :: CtxFlags
sign = CtxFlags (bit 1)

-- | Both 'verify' and 'sign' set (value 3).
signVerify :: CtxFlags
signVerify = CtxFlags (getCtxFlags verify .|. getCtxFlags sign)

-- | Serialization flag selecting compressed output (value 1).
compressed :: SerFlags
compressed = SerFlags (bit 0)

-- | Serialization flag selecting uncompressed output (value 0).
uncompressed :: SerFlags
uncompressed = SerFlags zeroBits

-- | Run an action on the bytes of a 'ByteString', presented to it as the
-- @(Ptr CUChar, CSize)@ pair the C API expects.  No copy is made (this goes
-- through 'unsafeUseAsCStringLen'), so the action must not mutate the bytes
-- or keep the pointer after it returns.
useByteString :: ByteString -> ((Ptr CUChar, CSize) -> IO a) -> IO a
useByteString bs f = unsafeUseAsCStringLen bs (f . toCBuffer)
  where
    -- Reinterpret the CString/Int pair as unsigned bytes and a C size.
    toCBuffer (ptr, len) = (castPtr ptr, fromIntegral len)

-- | Copy @len@ bytes starting at @ptr@ into a freshly allocated
-- 'ByteString' (via 'packCStringLen'), so the result does not alias the
-- C memory it was read from.
packByteString :: (Ptr CUChar, CSize) -> IO ByteString
packByteString = \(ptr, len) ->
    let cstr = (castPtr ptr, fromIntegral len)
    in  packCStringLen cstr

-- | Read exactly @n@ bytes starting at @p@ into a freshly allocated
-- 'ByteString'.  Shared by the 'peek' methods of the fixed-size buffer
-- types below.
peekFixedBytes :: Int -> Ptr a -> IO ByteString
peekFixedBytes n p = packByteString (castPtr p, fromIntegral n)

-- | Copy the bytes of @bs@ into the @n@-byte buffer at @p@.
--
-- @bs@ must be exactly @n@ bytes long.  Copying @n@ bytes out of a shorter
-- 'ByteString' would read past the end of its memory, and silently
-- truncating a longer one would hide a caller bug (e.g. a wrongly sized key
-- or message).  Any length mismatch therefore raises an 'IOError' naming
-- the offending type (@what@).
pokeFixedBytes :: String -> Int -> Ptr a -> ByteString -> IO ()
pokeFixedBytes what n p bs = useByteString bs $ \(b, l) ->
    if fromIntegral l == n
        then copyArray (castPtr p) b n
        else ioError . userError $
            "Crypto.Secp256k1.Internal: " ++ what ++ " requires exactly "
                ++ show n ++ " bytes, got " ++ show l

-- Each wrapper below is a plain byte buffer: 'sizeOf' is its fixed length
-- and 'alignment' is 1, so values can be poked into / peeked from memory
-- handed to the C library.

instance Storable PubKey64 where
    sizeOf _ = 64
    alignment _ = 1
    peek p = PubKey64 <$> peekFixedBytes 64 p
    poke p (PubKey64 k) = pokeFixedBytes "PubKey64" 64 p k

instance Storable Sig64 where
    sizeOf _ = 64
    alignment _ = 1
    peek p = Sig64 <$> peekFixedBytes 64 p
    poke p (Sig64 k) = pokeFixedBytes "Sig64" 64 p k

instance Storable Msg32 where
    sizeOf _ = 32
    alignment _ = 1
    peek p = Msg32 <$> peekFixedBytes 32 p
    poke p (Msg32 k) = pokeFixedBytes "Msg32" 32 p k

instance Storable Seed32 where
    sizeOf _ = 32
    alignment _ = 1
    peek p = Seed32 <$> peekFixedBytes 32 p
    poke p (Seed32 k) = pokeFixedBytes "Seed32" 32 p k

instance Storable SecKey32 where
    sizeOf _ = 32
    alignment _ = 1
    peek p = SecKey32 <$> peekFixedBytes 32 p
    poke p (SecKey32 k) = pokeFixedBytes "SecKey32" 32 p k

instance Storable Tweak32 where
    sizeOf _ = 32
    alignment _ = 1
    peek p = Tweak32 <$> peekFixedBytes 32 p
    poke p (Tweak32 k) = pokeFixedBytes "Tweak32" 32 p k

instance Storable Nonce32 where
    sizeOf _ = 32
    alignment _ = 1
    peek p = Nonce32 <$> peekFixedBytes 32 p
    poke p (Nonce32 k) = pokeFixedBytes "Nonce32" 32 p k

instance Storable Algo16 where
    sizeOf _ = 16
    alignment _ = 1
    peek p = Algo16 <$> peekFixedBytes 16 p
    poke p (Algo16 k) = pokeFixedBytes "Algo16" 16 p k

-- | Decode a libsecp256k1 status code: 1 means success, 0 means failure.
--
-- Any other value is outside the contract this module relies on.  It is
-- reported with an 'error' that includes the offending code, rather than
-- an anonymous 'undefined', so the failure can be diagnosed.
isSuccess :: Ret -> Bool
isSuccess (Ret 1) = True
isSuccess (Ret 0) = False
isSuccess (Ret n) = error $
    "Crypto.Secp256k1.Internal.isSuccess: unexpected return code " ++ show n

{-# NOINLINE ctx #-}
-- | Process-wide context shared by the bindings.
--
-- It is built on first use as follows:
--
-- 1. 'context_create' with 'signVerify'.
-- 2. 'context_randomize' with a 32-byte 'Seed32' read from 'getEntropy'.
--
-- If randomization fails, evaluation aborts with an 'error'.  This module
-- never destroys the context.
--
-- 'unsafePerformIO' is what lets this be a top-level value.  The NOINLINE
-- pragma keeps GHC from inlining the definition, which could otherwise
-- duplicate the IO action and create more than one context.
ctx :: Ptr Ctx
ctx = unsafePerformIO $ do
    context <- context_create signVerify
    seed <- Seed32 <$> getEntropy 32
    -- 'with' allocates temporary storage, pokes the seed into it and
    -- passes the pointer on.
    status <- with seed (context_randomize context)
    if isSuccess status
        then return context
        else error "failed to randomize context"

-- | Binding to @secp256k1_context_create@: create a context for the given
-- 'CtxFlags' ('ctx' uses 'signVerify').
foreign import ccall
    "secp256k1.h secp256k1_context_create"
    context_create
    :: CtxFlags
    -> IO (Ptr Ctx)

-- | Binding to @secp256k1_context_clone@: copy an existing context.
foreign import ccall
    "secp256k1.h secp256k1_context_clone"
    context_clone
    :: Ptr Ctx
    -> IO (Ptr Ctx)

-- | Address of @secp256k1_context_destroy@.  The @&@ in the import string
-- imports a 'FunPtr' rather than a callable function, presumably so it can
-- serve as a 'ForeignPtr' finalizer; confirm with callers.
foreign import ccall
    "secp256k1.h &secp256k1_context_destroy"
    context_destroy
    :: FunPtr (Ptr Ctx -> IO ())

-- | Binding to @secp256k1_context_set_illegal_callback@: install a callback
-- on the context.  The callback receives a C message string and the @data@
-- pointer given here.
foreign import ccall
    "secp256k1.h secp256k1_context_set_illegal_callback"
    set_illegal_callback
    :: Ptr Ctx
    -> FunPtr (CString -> Ptr a -> IO ()) -- ^ message, data
    -> Ptr a                              -- ^ data
    -> IO ()

-- | Binding to @secp256k1_context_set_error_callback@.  It has the same
-- shape as 'set_illegal_callback' but installs the error callback.
foreign import ccall
    "secp256k1.h secp256k1_context_set_error_callback"
    set_error_callback
    :: Ptr Ctx
    -> FunPtr (CString -> Ptr a -> IO ()) -- ^ message, data
    -> Ptr a                              -- ^ data
    -> IO ()

-- | Binding to @secp256k1_ec_pubkey_parse@: parse an encoded public key,
-- given as a byte array and its size, into a 'PubKey64'.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_parse"
    ec_pubkey_parse
    :: Ptr Ctx
    -> Ptr PubKey64
    -> Ptr CUChar -- ^ encoded public key array
    -> CSize      -- ^ size of encoded public key array
    -> IO Ret

-- | Binding to @secp256k1_ec_pubkey_serialize@: encode a 'PubKey64' into a
-- caller-supplied buffer.  'SerFlags' chooses 'compressed' or
-- 'uncompressed'; the size cell is updated with the encoded length.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_serialize"
    ec_pubkey_serialize
    :: Ptr Ctx
    -> Ptr CUChar -- ^ array for encoded public key, must be large enough
    -> Ptr CSize  -- ^ size of encoded public key, will be updated
    -> Ptr PubKey64
    -> SerFlags
    -> IO Ret

-- | Binding to @secp256k1_ecdsa_signature_parse_der@: parse a DER-encoded
-- signature (byte array and size) into a 'Sig64'.
foreign import ccall
    "secp256k1.h secp256k1_ecdsa_signature_parse_der"
    ecdsa_signature_parse_der
    :: Ptr Ctx
    -> Ptr Sig64
    -> Ptr CUChar -- ^ encoded DER signature
    -> CSize      -- ^ size of encoded signature
    -> IO Ret

-- | Binding to @secp256k1_ecdsa_signature_serialize_der@: DER-encode a
-- 'Sig64' into a caller-supplied buffer, updating the size cell.
foreign import ccall
    "secp256k1.h secp256k1_ecdsa_signature_serialize_der"
    ecdsa_signature_serialize_der
    :: Ptr Ctx
    -> Ptr CUChar -- ^ array for encoded signature, must be large enough
    -> Ptr CSize  -- ^ size of encoded signature, will be updated
    -> Ptr Sig64
    -> IO Ret

-- | Binding to @secp256k1_ecdsa_verify@: check a 'Sig64' over a 'Msg32'
-- against a 'PubKey64'.  The 'Ret' status can be decoded with 'isSuccess'.
foreign import ccall
    "secp256k1.h secp256k1_ecdsa_verify"
    ecdsa_verify
    :: Ptr Ctx
    -> Ptr Sig64
    -> Ptr Msg32
    -> Ptr PubKey64
    -> IO Ret

-- TODO: bind the nonce functions exported by secp256k1.h so they can be
-- passed to 'ecdsa_sign'.
-- foreign import ccall
--     "secp256k1.h &secp256k1_nonce_function_rfc6979"
--     nonce_function_rfc6979
--     :: FunPtr (NonceFunction Seed32)
--
-- TODO: likewise for the default nonce function.
-- foreign import ccall
--     "secp256k1.h &secp256k1_nonce_function_default"
--     nonce_function_default
--     :: FunPtr (NonceFunction Seed32)

-- | Binding to @secp256k1_ecdsa_sign@.  Its arguments are the signature, the
-- message, the secret key, a nonce function and nonce data.  The type
-- variable @a@ ties the nonce-data pointer to the extra-data argument of
-- the 'NonceFunction'.
--
-- NOTE(review): presumably passing 'nullFunPtr' selects the library's
-- default nonce function; confirm against secp256k1.h.
foreign import ccall
    "secp256k1.h secp256k1_ecdsa_sign"
    ecdsa_sign
    :: Ptr Ctx
    -> Ptr Sig64
    -> Ptr Msg32
    -> Ptr SecKey32
    -> FunPtr (NonceFunction a)
    -> Ptr a -- ^ nonce data
    -> IO Ret

-- | Binding to @secp256k1_ec_seckey_verify@: check whether a 'SecKey32' is
-- acceptable to the library; decode the result with 'isSuccess'.
foreign import ccall
    "secp256k1.h secp256k1_ec_seckey_verify"
    ec_seckey_verify
    :: Ptr Ctx
    -> Ptr SecKey32
    -> IO Ret

-- | Binding to @secp256k1_ec_pubkey_create@: compute the 'PubKey64' for a
-- 'SecKey32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_create"
    ec_pubkey_create
    :: Ptr Ctx
    -> Ptr PubKey64
    -> Ptr SecKey32
    -> IO Ret

-- | Binding to @secp256k1_ec_privkey_export@: BER-encode a 'SecKey32' into
-- a caller-supplied buffer.  'SerFlags' chooses 'compressed' or
-- 'uncompressed', and the size cell is updated.
foreign import ccall
    "secp256k1.h secp256k1_ec_privkey_export"
    ec_privkey_export
    :: Ptr Ctx
    -> Ptr CUChar -- ^ array to store BER-encoded key (allocate 279 bytes)
    -> Ptr CSize -- ^ size of previous array, will be updated
    -> Ptr SecKey32
    -> SerFlags
    -> IO Ret

-- | Binding to @secp256k1_ec_privkey_import@: decode a BER-encoded private
-- key (byte array and size) into a 'SecKey32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_privkey_import"
    ec_privkey_import
    :: Ptr Ctx
    -> Ptr SecKey32
    -> Ptr CUChar -- ^ BER-encoded private key
    -> CSize
    -> IO Ret

-- | Binding to @secp256k1_ec_privkey_tweak_add@: additively tweak the
-- 'SecKey32' behind the pointer by a 'Tweak32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_privkey_tweak_add"
    ec_privkey_tweak_add
    :: Ptr Ctx
    -> Ptr SecKey32
    -> Ptr Tweak32
    -> IO Ret

-- | Binding to @secp256k1_ec_pubkey_tweak_add@: additively tweak the
-- 'PubKey64' behind the pointer by a 'Tweak32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_tweak_add"
    ec_pubkey_tweak_add
    :: Ptr Ctx
    -> Ptr PubKey64
    -> Ptr Tweak32
    -> IO Ret

-- | Binding to @secp256k1_ec_privkey_tweak_mul@: multiplicatively tweak the
-- 'SecKey32' behind the pointer by a 'Tweak32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_privkey_tweak_mul"
    ec_privkey_tweak_mul
    :: Ptr Ctx
    -> Ptr SecKey32
    -> Ptr Tweak32
    -> IO Ret

-- | Binding to @secp256k1_ec_pubkey_tweak_mul@: multiplicatively tweak the
-- 'PubKey64' behind the pointer by a 'Tweak32'.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_tweak_mul"
    ec_pubkey_tweak_mul
    :: Ptr Ctx
    -> Ptr PubKey64
    -> Ptr Tweak32
    -> IO Ret

-- | Binding to @secp256k1_context_randomize@: re-seed a context with a
-- 'Seed32'.  'ctx' calls it once with 32 bytes from 'getEntropy' and treats
-- a non-success 'Ret' as fatal.
foreign import ccall
    "secp256k1.h secp256k1_context_randomize"
    context_randomize
    :: Ptr Ctx
    -> Ptr Seed32
    -> IO Ret

-- | Binding to @secp256k1_ec_pubkey_combine@: combine an array of
-- 'PubKey64' pointers into the first 'PubKey64' argument.
--
-- NOTE(review): newer libsecp256k1 releases declare the count as @size_t@,
-- not @int@; confirm that 'CInt' matches the bundled secp256k1.h.
foreign import ccall
    "secp256k1.h secp256k1_ec_pubkey_combine"
    ec_pubkey_combine
    :: Ptr Ctx
    -> Ptr PubKey64 -- ^ pointer to public key storage
    -> Ptr (Ptr PubKey64) -- ^ pointer to array of public keys
    -> CInt -- ^ number of public keys
    -> IO Ret