{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE OverloadedStrings    #-}

module StrongSwan.SQL.Encoding where

import Control.Exception        (throw)
import Control.Lens             (over, each)
import Data.ASN1.BinaryEncoding (DER(..))
import Data.ASN1.Encoding       (encodeASN1', decodeASN1')
import Data.Bifunctor           (second)
import Data.ByteString.Char8    (ByteString, unpack)
import Data.Maybe               (fromJust)
import Data.Text                (Text)
import Data.Word                (Word8)
import Numeric                  (showHex)
import StrongSwan.SQL.Types
import Text.Read                (readMaybe)

import qualified Data.ByteString     as B
import qualified Data.Text           as T
import qualified Database.MySQL.Base as SQL

-- | Conversion between a record type and one row of MySQL values.
class SQLRow a where
    -- | Encode a record as the column values handed to the driver.
    --
    -- Several instances in this module leave the record's id column out
    -- here, while 'fromValues' expects it as the first column.
    toValues :: a -> [SQL.MySQLValue]

    -- | Rebuild a record from a result row. The instances in this module
    -- 'throw' when the row does not have the expected shape.
    fromValues :: [SQL.MySQLValue] -> a

-- | Conversion between a typed column value and the driver's 'SQL.MySQLValue'.
class SQLValue a where
    -- | Wrap a typed value for the driver.
    toSQL :: a -> SQL.MySQLValue

    -- | Unwrap a driver value. The instances in this module 'throw'
    -- @InvalidValueForType@ when the value has an unexpected constructor.
    fromSQL :: SQL.MySQLValue -> a

-- | Unsigned TINYINT values.
--
-- A signed TINYINT value is also accepted on read and converted with
-- 'fromIntegral', so negative values wrap around.
instance SQLValue (Value 'SQL.MySQLInt8U) where
    toSQL (TinyInt w) = SQL.MySQLInt8U w

    fromSQL v = case v of
        SQL.MySQLInt8U w -> TinyInt w
        SQL.MySQLInt8  i -> TinyInt (fromIntegral i)
        _                -> throw (InvalidValueForType "TinyInt" (show v))

-- | Unsigned SMALLINT values.
--
-- Like the TINYINT instance, this also accepts the signed column variant
-- ('SQL.MySQLInt16'), for tables declared without UNSIGNED. Only
-- non-negative signed values are converted. A negative value cannot be
-- represented and is rejected with 'InvalidValueForType' rather than
-- silently wrapping around.
instance SQLValue (Value 'SQL.MySQLInt16U) where
    toSQL   (SmallInt x)        = SQL.MySQLInt16U x
    fromSQL (SQL.MySQLInt16U x) = SmallInt x
    fromSQL (SQL.MySQLInt16 x)
      | x >= 0                  = SmallInt $ fromIntegral x
    fromSQL v                   = throw $ InvalidValueForType "SmallInt" (show v)


-- | Unsigned INT values (ids, lifetimes, timers, ...).
--
-- Like the TINYINT instance, this also accepts the signed column variant
-- ('SQL.MySQLInt32'), for tables declared without UNSIGNED. Only
-- non-negative signed values are converted. A negative value cannot be
-- represented and is rejected with 'InvalidValueForType' rather than
-- silently wrapping around.
instance SQLValue (Value 'SQL.MySQLInt32U) where
    toSQL   (IntWord32 x)       = SQL.MySQLInt32U x
    fromSQL (SQL.MySQLInt32U x) = IntWord32 x
    fromSQL (SQL.MySQLInt32 x)
      | x >= 0                  = IntWord32 $ fromIntegral x
    fromSQL v                   = throw $ InvalidValueForType "IntWord32" (show v)

-- | Raw binary column values ('SQL.MySQLBytes').
instance SQLValue (Value 'SQL.MySQLBytes) where
    toSQL (VarBinary raw) = SQL.MySQLBytes raw

    fromSQL v = case v of
        SQL.MySQLBytes raw -> VarBinary raw
        _                  -> throw (InvalidValueForType "VarBinary" (show v))

-- | Nullable text column values: 'NullChar' stands for SQL NULL in both
-- directions.
instance SQLValue (Value 'SQL.MySQLText) where
    toSQL v = case v of
        VarChar txt -> SQL.MySQLText txt
        NullChar    -> SQL.MySQLNull

    fromSQL v = case v of
        SQL.MySQLText txt -> VarChar txt
        SQL.MySQLNull     -> NullChar
        _                 -> throw (InvalidValueForType "VarChar" (show v))

-- | Decode an integer id column via 'fromInt' and wrap it in 'Just'.
--
-- The result is never 'Nothing'. A value of the wrong type surfaces as an
-- exception thrown by 'fromSQL'.
fromId :: SQL.MySQLValue -> Maybe Int
fromId v = Just (fromInt (fromSQL v))

-- | Identities are stored as a (type tag, data) pair.
--
-- 'toValues' writes only those two columns; 'fromValues' expects the id
-- column in front of them. Tags used here: 0 any, 1 IPv4 address, 2 name,
-- 3 e-mail, 5 IPv6 address, 9 DER-encoded ASN.1, 11 opaque bytes.
-- NOTE(review): these appear to follow strongSwan's @id_type_t@
-- numbering (11 = ID_KEY_ID); confirm against the schema.
instance SQLRow Identity where
    toValues (AnyID _)                = [toSQL $ toTinyInt (0::Word8),  toSQL $ toVarBinary ("%any" :: String)]
    toValues (IPv4AddrID _ v4)        = [toSQL $ toTinyInt (1::Word8),  toSQL $ toVarBinary v4]
    toValues (NameID _ str)           = [toSQL $ toTinyInt (2::Word8),  toSQL $ toVarBinary str]
    toValues (EmailID _ local domain) = [toSQL $ toTinyInt (3::Word8),  toSQL $ toVarBinary (local <> "@" <> domain)]
    toValues (IPv6AddrID _ v6)        = [toSQL $ toTinyInt (5::Word8),  toSQL $ toVarBinary v6]
    toValues (ASN1ID _ elements)      = [toSQL $ toTinyInt (9::Word8),  toSQL $ toVarBinary (encodeASN1' DER elements)]
    toValues (OpaqueID _ bytes)       = [toSQL $ toTinyInt (11::Word8), toSQL $ toVarBinary bytes]

    fromValues [iD, SQL.MySQLInt8U 0, _]  = AnyID      (fromId iD)
    fromValues [iD, SQL.MySQLInt8U 1, v]  = IPv4AddrID (fromId iD) . fromVarBinary $ fromSQL v
    fromValues [iD, SQL.MySQLInt8U 2, v]  = NameID     (fromId iD) . fromVarBinary $ fromSQL v
    fromValues [iD, SQL.MySQLInt8U 3, v]  = uncurry (EmailID $ fromId iD)  . parseEmail . fromVarBinary $ fromSQL v
    fromValues [iD, SQL.MySQLInt8U 5, v]  = IPv6AddrID (fromId iD) . fromVarBinary $ fromSQL v
    fromValues [iD, SQL.MySQLInt8U 9, v]  = ASN1ID     (fromId iD) . either throw id . decodeASN1' DER . fromVarBinary $ fromSQL v
    -- Tag 11 is what 'toValues' writes for 'OpaqueID'. Without this case an
    -- opaque identity could be stored but never read back.
    fromValues [iD, SQL.MySQLInt8U 11, v] = OpaqueID   (fromId iD) . fromVarBinary $ fromSQL v
    fromValues v                          = throw $ SQLValuesMismatch "Identity" (show v)

-- | Split a stored e-mail identity into its local part and its domain.
--
-- The split happens at the /last/ at-sign. A quoted local part may itself
-- contain an at-sign, but the domain never does. For example
-- @\"a\@b\"\@example.com@ yields the local part @\"a\@b\"@ and the domain
-- @example.com@. Input without an at-sign is returned whole as the local
-- part, with an empty domain.
--
-- Bytes are mapped one-to-one to characters (via "Data.ByteString.Char8").
parseEmail :: ByteString -> (Text, Text)
parseEmail str
    | T.null localAt = (address, T.empty)
    | otherwise      = (T.dropEnd 1 localAt, domain)
  where
    address           = T.pack (unpack str)
    -- 'localAt' keeps the trailing separator; it is empty when none exists.
    (localAt, domain) = T.breakOnEnd (T.pack "@") address

-- | IKE configuration rows.
--
-- 'toValues' omits the id column (presumably assigned by the database on
-- insert; confirm against callers). 'fromValues' expects it first.
instance SQLRow IKEConfig where
    toValues IKEConfig {..} = [
      toSQL $ toTinyInt _ikeReqCert,
      toSQL $ toTinyInt _ikeForceEncap,
      toSQL . toVarChar $ Just _ikeLocalAddress,
      toSQL . toVarChar $ Just _ikeRemoteAddress]
    fromValues (iD :
                reqCert :
                forceEncap :
                localAddress :
                remoteAddress :
                []) = IKEConfig {
                          _ikeId            = return . fromInt $ fromSQL iD,
                          _ikeReqCert       = fromTinyInt $ fromSQL reqCert,
                          _ikeForceEncap    = fromTinyInt $ fromSQL forceEncap,
                          _ikeLocalAddress  = requiredVarChar localAddress,
                          _ikeRemoteAddress = requiredVarChar remoteAddress
                      }
      where
        -- Both address fields are mandatory. Report a NULL column with the
        -- module's own exception instead of an opaque 'fromJust' failure.
        requiredVarChar v =
          maybe (throw $ InvalidValueForType "VarChar" (show v)) id
                (fromVarChar (fromSQL v))
    fromValues xs = throw $ SQLValuesMismatch "IKEConfig" (show xs)


-- | Child SA configuration rows.
--
-- 'toValues' omits the id column, while 'fromValues' expects it first.
instance SQLRow ChildSAConfig where
    toValues ChildSAConfig {..} = [
      toSQL . toVarChar $ Just _childSAName,
      toSQL $ toInt _childSALifeTime,
      toSQL $ toInt _childSARekeyTime,
      toSQL $ toInt _childSAJitter,
      toSQL $ toVarChar _childSAUpDown,
      toSQL $ toTinyInt _childSAHostAccess,
      toSQL $ toTinyInt _childSAMode,
      toSQL $ toTinyInt _childSAStartAction,
      toSQL $ toTinyInt _childSADPDAction,
      toSQL $ toTinyInt _childSACloseAction,
      toSQL $ toTinyInt _childSAIPCompression,
      toSQL $ toInt     _childSAReqID,
      toSQL $ toVarChar _childSAMark ]
    fromValues (iD:
                name :
                lifeTime :
                rekeyTime :
                jitter :
                upDown :
                hostAccess :
                mode :
                startAction :
                dpdAction :
                closeAction :
                ipCompression :
                reqID :
                mark :
                []) = ChildSAConfig {
                           _childSAId            = return . fromInt       $ fromSQL iD,
                           _childSAName          = requiredVarChar name,
                           _childSALifeTime      = fromInt     $ fromSQL lifeTime,
                           _childSARekeyTime     = fromInt     $ fromSQL rekeyTime,
                           _childSAJitter        = fromInt     $ fromSQL jitter,
                           _childSAUpDown        = fromVarChar $ fromSQL upDown,
                           _childSAHostAccess    = fromTinyInt $ fromSQL hostAccess,
                           _childSAMode          = fromTinyInt $ fromSQL mode,
                           _childSAStartAction   = fromTinyInt $ fromSQL startAction,
                           _childSADPDAction     = fromTinyInt $ fromSQL dpdAction,
                           _childSACloseAction   = fromTinyInt $ fromSQL closeAction,
                           _childSAIPCompression = fromTinyInt $ fromSQL ipCompression,
                           _childSAReqID         = fromInt     $ fromSQL reqID,
                           _childSAMark          = fromVarChar $ fromSQL mark
                      }
      where
        -- The name is mandatory (up-down script and mark stay optional).
        -- Report a NULL name with the module's own exception instead of an
        -- opaque 'fromJust' failure.
        requiredVarChar v =
          maybe (throw $ InvalidValueForType "VarChar" (show v)) id
                (fromVarChar (fromSQL v))
    fromValues xs = throw $ SQLValuesMismatch "ChildSAConfig" (show xs)

-- | Encode an optional identity id as a text column value.
--
-- @Just n@ is rendered as the decimal text of @n@. 'Nothing' is handed to
-- 'toVarChar' unchanged (presumably becoming SQL NULL).
iDToVarChar :: Maybe Int -> SQL.MySQLValue
iDToVarChar mId = toSQL (toVarChar (fmap (T.pack . show) mId))

-- | Decode a text column holding a decimal identity id.
--
-- Yields 'Nothing' when 'fromVarChar' yields 'Nothing', or when the text
-- does not parse as an 'Int'.
varCharToId :: SQL.MySQLValue -> Maybe Int
varCharToId v = fromVarChar (fromSQL v) >>= readMaybe . T.unpack

-- | Peer configuration rows.
--
-- 'toValues' omits the id column, while 'fromValues' expects it first.
-- Local and remote identity ids are stored as text columns (see
-- 'iDToVarChar' / 'varCharToId').
instance SQLRow PeerConfig where
    toValues PeerConfig {..} = [
        toSQL . toVarChar  $ Just _peerCfgName,
        toSQL $ toTinyInt  _peerCfgIKEVersion,
        -- An IKE config id is required to write a peer config. Fail with a
        -- descriptive message rather than a bare 'fromJust' error.
        maybe (error "PeerConfig.toValues: _peerCfgIKEConfigId is required but was Nothing")
              (toSQL . toInt)
              _peerCfgIKEConfigId,
        iDToVarChar _peerCfgLocalId,
        iDToVarChar _peerCfgRemoteId,
        toSQL $ toTinyInt  _peerCfgCertPolicy,
        toSQL $ toTinyInt  _peerCfgUniqueIds,
        toSQL $ toTinyInt  _peerCfgAuthMethod,
        toSQL $ toTinyInt  _peerCfgEAPType,
        toSQL $ toSmallInt _peerCfgEAPVendor,
        toSQL $ toTinyInt  _peerCfgKeyingTries,
        toSQL $ toInt      _peerCfgRekeyTime,
        toSQL $ toInt      _peerCfgReauthTime,
        toSQL $ toInt      _peerCfgJitter,
        toSQL $ toInt      _peerCfgOverTime,
        toSQL $ toTinyInt  _peerCfgMobike,
        toSQL $ toInt      _peerCfgDPDDelay,
        toSQL $ toVarChar  _peerCfgVirtual,
        toSQL $ toVarChar  _peerCfgPool,
        toSQL $ toTinyInt  _peerCfgMediation,
        toSQL $ toInt      _peerCfgMediatedBy,
        toSQL $ toInt      _peerCfgPeerId ]
    fromValues (iD          :
                name        :
                ikeVersion  :
                ikeConfig   :
                localId     :
                remoteId    :
                certPolicy  :
                uniqueIds   :
                authMethod  :
                eapType     :
                eapVendor   :
                keyingTries :
                rekeyTime   :
                reauthTime  :
                jitter      :
                overTime    :
                mobike      :
                dpdDelay    :
                virtual     :
                pool        :
                mediation   :
                mediatedBy  :
                peerId      :
                []) = PeerConfig {
                           _peerCfgId          =  return . fromInt $ fromSQL iD,
                           _peerCfgName        =  requiredVarChar name,
                           _peerCfgIKEVersion  =  fromTinyInt $ fromSQL ikeVersion,
                           _peerCfgIKEConfigId =  return . fromInt $ fromSQL ikeConfig,
                           _peerCfgLocalId     =  varCharToId localId,
                           _peerCfgRemoteId    =  varCharToId remoteId,
                           _peerCfgCertPolicy  =  fromTinyInt  $ fromSQL certPolicy,
                           _peerCfgUniqueIds   =  fromTinyInt  $ fromSQL uniqueIds,
                           _peerCfgAuthMethod  =  fromTinyInt  $ fromSQL authMethod,
                           _peerCfgEAPType     =  fromTinyInt  $ fromSQL eapType,
                           _peerCfgEAPVendor   =  fromSmallInt $ fromSQL eapVendor,
                           _peerCfgKeyingTries =  fromTinyInt  $ fromSQL keyingTries,
                           _peerCfgRekeyTime   =  fromInt      $ fromSQL rekeyTime,
                           _peerCfgReauthTime  =  fromInt      $ fromSQL reauthTime,
                           _peerCfgJitter      =  fromInt      $ fromSQL jitter,
                           _peerCfgOverTime    =  fromInt      $ fromSQL overTime,
                           _peerCfgMobike      =  fromTinyInt  $ fromSQL mobike,
                           _peerCfgDPDDelay    =  fromInt      $ fromSQL dpdDelay,
                           _peerCfgVirtual     =  fromVarChar  $ fromSQL virtual,
                           _peerCfgPool        =  fromVarChar  $ fromSQL pool,
                           _peerCfgMediation   =  fromTinyInt  $ fromSQL mediation,
                           _peerCfgMediatedBy  =  fromInt      $ fromSQL mediatedBy,
                           _peerCfgPeerId      =  fromInt      $ fromSQL peerId
                      }
      where
        -- The name is mandatory. Report a NULL column with the module's own
        -- exception instead of an opaque 'fromJust' failure.
        requiredVarChar v =
          maybe (throw $ InvalidValueForType "VarChar" (show v)) id
                (fromVarChar (fromSQL v))
    fromValues xs = throw $ SQLValuesMismatch "PeerConfig" (show xs)

-- | Link rows between a peer config and a child SA config, stored as
-- (peer config id, child config id).
instance SQLRow Peer2ChildConfig where
    toValues Peer2ChildConfig { p2cPeerCfgId = peer, p2cChildCfgId = child } =
      [ toSQL (toInt peer)
      , toSQL (toInt child)
      ]

    fromValues [peer, child] =
      Peer2ChildConfig
        { p2cPeerCfgId  = fromInt (fromSQL peer)
        , p2cChildCfgId = fromInt (fromSQL child)
        }
    fromValues xs = throw $ SQLValuesMismatch "Peer2ChildConfig" (show xs)

-- | Traffic selector rows.
--
-- 'toValues' writes the six data columns without the id, while
-- 'fromValues' expects the id column first.
instance SQLRow TrafficSelector where
    toValues TrafficSelector {..} =
      [ toSQL (toTinyInt   _tsType)
      , toSQL (toSmallInt  _tsProtocol)
      , toSQL (toVarBinary _tsStartAddr)
      , toSQL (toVarBinary _tsEndAddr)
      , toSQL (toSmallInt  _tsStartPort)
      , toSQL (toSmallInt  _tsEndPort)
      ]

    fromValues [iD, kind, proto, fromAddr, toAddr, fromPort, toPort] =
      TrafficSelector
        { _tsId        = pure (fromInt (fromSQL iD))
        , _tsType      = fromTinyInt   (fromSQL kind)
        , _tsProtocol  = fromSmallInt  (fromSQL proto)
        , _tsStartAddr = fromVarBinary (fromSQL fromAddr)
        , _tsEndAddr   = fromVarBinary (fromSQL toAddr)
        , _tsStartPort = fromSmallInt  (fromSQL fromPort)
        , _tsEndPort   = fromSmallInt  (fromSQL toPort)
        }
    fromValues xs = throw $ SQLValuesMismatch "TrafficSelector" (show xs)

-- | Link rows between a child SA config and a traffic selector, stored as
-- (child config id, traffic selector id, selector kind).
instance SQLRow Child2TSConfig where
    toValues Child2TSConfig {..} =
      [ toSQL (toInt     c2tsChildCfgId)
      , toSQL (toInt     c2tsTrafficSelectorCfgId)
      , toSQL (toTinyInt c2tsTrafficSelectorKind)
      ]

    fromValues [child, selector, kind] =
      Child2TSConfig
        { c2tsChildCfgId           = fromInt     (fromSQL child)
        , c2tsTrafficSelectorCfgId = fromInt     (fromSQL selector)
        , c2tsTrafficSelectorKind  = fromTinyInt (fromSQL kind)
        }
    fromValues xs = throw $ SQLValuesMismatch "Child2TSConfig" (show xs)

-- | Shared secret rows.
--
-- 'toValues' writes (type, data) without the id, while 'fromValues'
-- expects (id, type, data).
instance SQLRow SharedSecret where
    toValues SharedSecret {..} =
      [ toSQL (toTinyInt   _ssType)
      , toSQL (toVarBinary _ssData)
      ]

    fromValues [iD, kind, payload] =
      SharedSecret
        { _ssId   = pure (fromInt (fromSQL iD))
        , _ssType = fromTinyInt   (fromSQL kind)
        , _ssData = fromVarBinary (fromSQL payload)
        }
    fromValues xs = throw $ SQLValuesMismatch "SharedSecret" (show xs)

-- | Link rows between a shared secret and an identity, stored as
-- (shared secret id, identity id).
instance SQLRow SharedSecretIdentity where
    toValues SharedSecretIdentity {..} =
      [ toSQL (toInt _sharedSecretId)
      , toSQL (toInt _identityId)
      ]

    fromValues (secret : ident : []) =
      SharedSecretIdentity
        { _sharedSecretId = fromInt (fromSQL secret)
        , _identityId     = fromInt (fromSQL ident)
        }
    fromValues xs = throw $ SQLValuesMismatch "SharedSecretIdentity" (show xs)


-- | Render a 'ByteString' as lowercase hexadecimal, two digits per byte.
--
-- Every byte is zero-padded, so the output is always twice as long as the
-- input and can be decoded unambiguously. For example, the bytes
-- @[0x01, 0x23]@ become @\"0123\"@, not @\"123\"@.
encodeHex :: ByteString -> String
encodeHex = B.foldr hexByte ""
  where
    -- 'showHex' drops leading zeros, so pad single-digit bytes by hand.
    hexByte :: Word8 -> String -> String
    hexByte w rest
      | w < 0x10  = '0' : showHex w rest
      | otherwise = showHex w rest