module Crypto.Signature
  ( signParams
  , signJSON
  , hmacSHA256
  , signRaw

  , mkHexHash
  , signParams_
  , signJSON_
  , signRaw_
  ) where

import           Crypto.Hash           (Digest, SHA256)
import           Crypto.MAC            (HMAC (..), hmac)
import           Data.Aeson            (Value (..))
import           Data.Byteable         (toBytes)
import qualified Data.ByteString.Char8 as B (ByteString, concat, empty, pack,
                                             unpack)
import           Data.CaseInsensitive  (CI, mk)
import qualified Data.HashMap.Lazy     as LH (HashMap, toList)
import           Data.HexString        (fromBytes, toText)
import           Data.List             (sortOn)
import           Data.Scientific       (Scientific, floatingOrInteger)
import qualified Data.Text             as T (Text, unpack)
import           Data.Text.Encoding    (encodeUtf8)
import qualified Data.Text.Lazy        as LT (Text, toStrict, unpack)
import qualified Data.Vector           as V (Vector, toList)

-- | HMAC-SHA256 a message with the given key and return the digest as a
-- case-insensitive hex string.
hmacSHA256
  :: B.ByteString     -- ^ HMAC key
  -> B.ByteString     -- ^ message to authenticate
  -> CI B.ByteString
hmacSHA256 salt message = mkHexHash (mkHmacSHA256Hash salt) message

-- | Compute the raw HMAC-SHA256 digest of a message.
mkHmacSHA256Hash
  :: B.ByteString     -- ^ HMAC key
  -> B.ByteString     -- ^ message to authenticate
  -> Digest SHA256
mkHmacSHA256Hash key message = hmacGetDigest (hmac key message)

-- | Hash a byte string with the supplied digest function and render the
-- digest as a hex string.
--
-- The result is wrapped in 'CI', so comparing it against another signature
-- ignores letter case.
mkHexHash
  :: (B.ByteString -> Digest a)  -- ^ digest function (plain hash or keyed HMAC)
  -> B.ByteString                -- ^ bytes to hash
  -> CI B.ByteString
mkHexHash mkHash input = mk (encodeUtf8 hexText)
  where
    -- digest -> raw bytes -> hex string -> Text
    hexText = toText (fromBytes (toBytes (mkHash input)))

-- | Sort text parameters by key and concatenate them, UTF-8 encoded, as
-- @key1 value1 key2 value2 ...@ with no separators.
--
-- The sort is stable, so entries that share a key keep their input order.
sortAndJoinTextParams :: [(LT.Text, LT.Text)] -> B.ByteString
sortAndJoinTextParams params =
  -- A single 'B.concat' over the flattened pieces copies each byte once.
  -- The previous right-nested @B.concat [k, v, rest]@ re-copied the whole
  -- tail for every entry, which is quadratic in the number of parameters.
  B.concat [ utf8 piece | (k, v) <- sortOn fst params, piece <- [k, v] ]
  where
    -- Sorting on the lazy 'LT.Text' key directly orders by code point,
    -- the same order as sorting on @LT.unpack@, but without building a
    -- 'String' per key.
    utf8 :: LT.Text -> B.ByteString
    utf8 = encodeUtf8 . LT.toStrict

-- | Sign text parameters using a caller-supplied hash function.
--
-- The parameters are sorted by key and concatenated by
-- 'sortAndJoinTextParams'. The result is hashed and returned as a
-- case-insensitive hex string.
signParams_
  :: (B.ByteString -> Digest a)  -- ^ digest function
  -> [(LT.Text, LT.Text)]        -- ^ key/value parameters
  -> CI B.ByteString
signParams_ mkHash params = mkHexHash mkHash (sortAndJoinTextParams params)

-- | Sign text parameters using HMAC-SHA256 keyed by the given byte string.
signParams
  :: B.ByteString          -- ^ HMAC key
  -> [(LT.Text, LT.Text)]  -- ^ key/value parameters
  -> CI B.ByteString
signParams salt params = signParams_ (mkHmacSHA256Hash salt) params

-- | Flatten a JSON 'Value' into the byte string that gets hashed.
--
-- Each kind of value is flattened as follows:
--
-- * Objects: key/value pairs sorted by key, each key UTF-8 encoded and
--   immediately followed by its flattened value.
-- * Arrays: the flattened elements, in order.
-- * Strings: their UTF-8 bytes.
-- * Numbers: rendered with 'show' (see 'showNumber').
-- * Booleans: @true@ or @false@.
-- * @null@: nothing.
--
-- No separators are inserted anywhere.
--
-- NOTE(review): because nothing delimits keys from values, distinct inputs
-- can flatten to the same bytes (e.g. @{"a":"bc"}@ and @{"ab":"c"}@).
-- Changing that would change every signature, so it is kept as-is.
sortAndJoinJSON :: Value -> B.ByteString
sortAndJoinJSON = v2b
  where
    -- Sorting on the 'T.Text' key directly orders by code point, the same
    -- order as sorting on @T.unpack@, without building a 'String' per key.
    sortHashMap :: LH.HashMap T.Text Value -> [(T.Text, Value)]
    sortHashMap = sortOn fst . LH.toList

    -- One 'B.concat' over all pieces copies each byte once. The previous
    -- right-nested @B.concat [k, v, rest]@ re-copied the whole tail for
    -- every entry, which is quadratic in the number of object fields.
    joinList :: [(T.Text, Value)] -> B.ByteString
    joinList kvs = B.concat (concatMap (\(k, v) -> [encodeUtf8 k, v2b v]) kvs)

    joinArray :: V.Vector Value -> B.ByteString
    joinArray = B.concat . map v2b . V.toList

    v2b :: Value -> B.ByteString
    v2b (Object o)   = joinList (sortHashMap o)
    v2b (Array xs)   = joinArray xs
    v2b (String s)   = encodeUtf8 s
    v2b (Number n)   = B.pack (showNumber n)
    v2b (Bool True)  = B.pack "true"
    v2b (Bool False) = B.pack "false"
    v2b Null         = B.empty

    -- Integral numbers are shown as an 'Integer' (no fractional part).
    -- Everything else is shown as a 'Double'.
    --
    -- NOTE(review): per the Data.Scientific docs, 'floatingOrInteger'
    -- computes @10^e@ to build the 'Integer'. A number with a huge
    -- exponent (e.g. @1e1000000000@) can therefore exhaust memory.
    -- If this is fed untrusted JSON, consider rejecting such exponents
    -- before signing. That changes output for those inputs, so it must
    -- be coordinated with the verifying side.
    showNumber :: Scientific -> String
    showNumber n = case floatingOrInteger n of
      Left d  -> show (d :: Double)
      Right i -> show (i :: Integer)

-- | Sign a JSON value using a caller-supplied hash function.
--
-- The value is flattened by 'sortAndJoinJSON', hashed, and returned as a
-- case-insensitive hex string.
signJSON_
  :: (B.ByteString -> Digest a)  -- ^ digest function
  -> Value                       -- ^ JSON payload
  -> CI B.ByteString
signJSON_ mkHash json = mkHexHash mkHash (sortAndJoinJSON json)

-- | Sign a JSON value using HMAC-SHA256 keyed by the given byte string.
signJSON
  :: B.ByteString  -- ^ HMAC key
  -> Value         -- ^ JSON payload
  -> CI B.ByteString
signJSON salt json = signJSON_ (mkHmacSHA256Hash salt) json

-- | Sort byte-string parameters by key and concatenate them as
-- @key1 value1 key2 value2 ...@ with no separators.
--
-- The sort is stable, so entries that share a key keep their input order.
-- 'B.ByteString' ordering is lexicographic on unsigned bytes, which is the
-- same order the previous @B.unpack@-based sort key produced.
sortAndJoinRawParams :: [(B.ByteString, B.ByteString)] -> B.ByteString
sortAndJoinRawParams params =
  -- A single 'B.concat' copies each byte once. The previous right-nested
  -- @B.concat [k, v, rest]@ re-copied the whole tail for every entry,
  -- which is quadratic in the number of parameters.
  B.concat [ piece | (k, v) <- sortOn fst params, piece <- [k, v] ]

-- | Sign byte-string parameters using a caller-supplied hash function.
--
-- The parameters are sorted by key and concatenated by
-- 'sortAndJoinRawParams'. The result is hashed and returned as a
-- case-insensitive hex string.
signRaw_
  :: (B.ByteString -> Digest a)        -- ^ digest function
  -> [(B.ByteString, B.ByteString)]    -- ^ key/value parameters
  -> CI B.ByteString
signRaw_ mkHash params = mkHexHash mkHash (sortAndJoinRawParams params)

-- | Sign byte-string parameters using HMAC-SHA256 keyed by the given byte
-- string.
signRaw
  :: B.ByteString                      -- ^ HMAC key
  -> [(B.ByteString, B.ByteString)]    -- ^ key/value parameters
  -> CI B.ByteString
signRaw salt params = signRaw_ (mkHmacSHA256Hash salt) params