{-# LANGUAGE OverloadedStrings #-}

module StrongSwan.SQL.Statements where

import Control.Exception         (fromException)
import Control.Monad             (mapM_,void)
import Control.Monad.IO.Class    (MonadIO)
import Control.Monad.Fail        (MonadFail)
import Control.Monad.Failable
import Database.MySQL.Base       (MySQLConn)
import StrongSwan.SQL.Types

import qualified Database.MySQL.Base as SQL

-- | Set up the schema on @conn@ and prepare every statement the library uses.
--
-- Each table is created through 'initializeWith' before the statements that
-- operate on it are prepared; the two name-lookup statements are prepared
-- directly with 'prepare'. The resulting handles are returned in a
-- 'PreparedStatements' record.
--
-- The list patterns below rely on 'initializeWith' returning exactly one
-- handle per statement, in the order given; a mismatch is reported through
-- the monad's 'MonadFail' instance.
prepareStatements :: (Failable m, MonadIO m, MonadFail m) => SQL.MySQLConn -> m (PreparedStatements SQL.StmtID)
prepareStatements conn = do
    -- every helper below reads the connection from this implicit parameter
    let ?conn = conn

    -- child SA configurations
    [insChildSA, updChildSA, selChildSA, delChildSA] <-
        initializeWith createChildSATableStatement
            [ createChildSAStatement
            , updateChildSAStatement
            , findChildSAStatement
            , deleteChildSAStatement
            ]
    selChildSAByName <- prepare findChildSAByNameStatement

    -- IKE configurations
    [insIKE, updIKE, selIKE, delIKE] <-
        initializeWith createIKETableStatement
            [ createIKEStatement
            , updateIKEStatement
            , findIKEStatement
            , deleteIKEStatement
            ]

    -- peer configurations
    [insPeer, updPeer, selPeer, delPeer] <-
        initializeWith createPeerTableStatement
            [ createPeerStatement
            , updatePeerStatement
            , findPeerStatement
            , deletePeerStatement
            ]
    selPeerByName <- prepare findPeerByNameStatement

    -- peer <-> child link table
    [insP2C, updP2C, selP2C, delP2C] <-
        initializeWith createP2CTableStatement
            [ createP2CStatement
            , updateP2CStatement
            , findP2CStatement
            , deleteP2CStatement
            ]

    -- traffic selectors
    [insTS, updTS, selTS, delTS] <-
        initializeWith createTSTableStatement
            [ createTSStatement
            , updateTSStatement
            , findTSStatement
            , deleteTSStatement
            ]

    -- child <-> traffic selector link table
    [insC2TS, updC2TS, selC2TS, delC2TS] <-
        initializeWith createC2TSTableStatement
            [ createC2TSStatement
            , updateC2TSStatement
            , findC2TSStatement
            , deleteC2TSStatement
            ]

    -- identities
    [insIdentity, updIdentity, selIdentity, selIdentityBySelf, delIdentity] <-
        initializeWith createIdentityTable
            [ createIdentityStatement
            , updateIdentityStatement
            , findIdentityStatement
            , findIdentityBySelfStatement
            , deleteIdentityStatement
            ]

    -- shared secrets
    [insSecret, updSecret, selSecret, delSecret] <-
        initializeWith createSharedSecretTable
            [ createSharedSecretStatement
            , updateSharedSecretStatement
            , findSharedSecretStatement
            , deleteSharedSecretStatement
            ]

    -- shared secret <-> identity link table
    [insSSIdentity, updSSIdentity, selSSIdentity, delSSIdentity] <-
        initializeWith createSSIdentityTable
            [ createSSIdentityStatement
            , updateSSIdentityStatement
            , findSSIdentityStatement
            , deleteSSIdentityStatement
            ]

    -- named IPSec configurations (no update statement exists for these)
    [insIPSec, selIPSec, delIPSec] <-
        initializeWith createIPSecTableStatement
            [ createIPSecStatement
            , findIPSecStatement
            , deleteIPSecStatement
            ]

    -- tables that are only created here; no statements are prepared for them
    mapM_ (\ddl -> initializeWith ddl [])
          [ createCertificatesTableStatement
          , createProposalsTableStatement
          ]

    pure PreparedStatements
        { createChildSAStmt      = insChildSA
        , updateChildSAStmt      = updChildSA
        , findChildSAStmt        = selChildSA
        , findChildSAByNameStmt  = selChildSAByName
        , deleteChildSAStmt      = delChildSA
        , createIKEStmt          = insIKE
        , updateIKEStmt          = updIKE
        , findIKEStmt            = selIKE
        , deleteIKEStmt          = delIKE
        , createPeerStmt         = insPeer
        , updatePeerStmt         = updPeer
        , findPeerStmt           = selPeer
        , findPeerByNameStmt     = selPeerByName
        , deletePeerStmt         = delPeer
        , createP2CStmt          = insP2C
        , updateP2CStmt          = updP2C
        , findP2CStmt            = selP2C
        , deleteP2CStmt          = delP2C
        , createTSStmt           = insTS
        , updateTSStmt           = updTS
        , findTSStmt             = selTS
        , deleteTSStmt           = delTS
        , createC2TSStmt         = insC2TS
        , updateC2TSStmt         = updC2TS
        , findC2TSStmt           = selC2TS
        , deleteC2TSStmt         = delC2TS
        , createIdentityStmt     = insIdentity
        , updateIdentityStmt     = updIdentity
        , findIdentityStmt       = selIdentity
        , findIdentityBySelfStmt = selIdentityBySelf
        , deleteIdentityStmt     = delIdentity
        , createSharedSecretStmt = insSecret
        , updateSharedSecretStmt = updSecret
        , findSharedSecretStmt   = selSecret
        , deleteSharedSecretStmt = delSecret
        , createSSIdentityStmt   = insSSIdentity
        , updateSSIdentityStmt   = updSSIdentity
        , findSSIdentityStmt     = selSSIdentity
        , deleteSSIdentityStmt   = delSSIdentity
        , createIPSecStmt        = insIPSec
        , findIPSecStmt          = selIPSec
        , deleteIPSecStmt        = delIPSec
        }

-- | Prepare one query on the connection supplied by the implicit @?conn@
-- parameter, running the underlying IO action through 'failableIO'.
prepare :: (?conn :: MySQLConn, Failable m, MonadIO m) => SQL.Query -> m SQL.StmtID
prepare query = failableIO (SQL.prepareStmt ?conn query)

-- | Overwrites the four data columns of the @ike_configs@ row selected by
-- @id@ (last parameter).
updateIKEStatement :: SQL.Query
updateIKEStatement =
    "UPDATE ike_configs SET certreq = ?, force_encap = ?, local = ?, remote = ? WHERE id = ?;"

-- | DDL for the @ike_configs@ table. The literal is split with string gaps
-- for readability only; the resulting query text is a single line.
createIKETableStatement :: SQL.Query
createIKETableStatement =
    "CREATE TABLE `ike_configs` ( \
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`certreq` tinyint(3) unsigned NOT NULL default '1', \
    \`force_encap` tinyint(1) NOT NULL default '0', \
    \`local` varchar(128) collate utf8_unicode_ci NOT NULL, \
    \`remote` varchar(128) collate utf8_unicode_ci NOT NULL, \
    \PRIMARY KEY (`id`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @ike_configs@ row; @id@ is left to the table's
-- auto-increment.
--
-- NOTE(review): the keyword is spelled @VAlUES@ (lower-case @l@); presumably
-- harmless because MySQL keywords are case-insensitive, so it is left as is.
createIKEStatement :: SQL.Query
createIKEStatement =
    "INSERT INTO ike_configs (certreq, force_encap, local, remote) VAlUES (?, ?, ?, ?);"

-- | Selects the @ike_configs@ row with the given @id@.
findIKEStatement :: SQL.Query
findIKEStatement =
    "SELECT * FROM ike_configs WHERE id = ?;"

-- | Deletes the @ike_configs@ row with the given @id@.
deleteIKEStatement :: SQL.Query
deleteIKEStatement =
    "DELETE FROM ike_configs WHERE id = ?;"

-- | Overwrites all thirteen data columns of the @child_configs@ row selected
-- by @id@ (last parameter). The query text has no trailing @;@.
updateChildSAStatement :: SQL.Query
updateChildSAStatement =
    "UPDATE child_configs SET \
    \name = ?, lifetime = ?, rekeytime = ?, jitter = ?, \
    \updown = ?, hostaccess = ?, mode = ?, \
    \start_action = ?, dpd_action = ?, close_action = ?, \
    \ipcomp = ?, reqid = ?, mark = ? \
    \WHERE id = ?"

-- | DDL for the @child_configs@ table, indexed on @name@. The literal is
-- split with string gaps for readability only.
createChildSATableStatement :: SQL.Query
createChildSATableStatement =
    "CREATE TABLE `child_configs` ( \
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`name` varchar(32) collate utf8_unicode_ci NOT NULL, \
    \`lifetime` mediumint(8) unsigned NOT NULL default '1500', \
    \`rekeytime` mediumint(8) unsigned NOT NULL default '1200', \
    \`jitter` mediumint(8) unsigned NOT NULL default '60', \
    \`updown` varchar(128) collate utf8_unicode_ci default NULL, \
    \`hostaccess` tinyint(1) unsigned NOT NULL default '0', \
    \`mode` tinyint(4) unsigned NOT NULL default '2', \
    \`start_action` tinyint(4) unsigned NOT NULL default '0', \
    \`dpd_action` tinyint(4) unsigned NOT NULL default '0', \
    \`close_action` tinyint(4) unsigned NOT NULL default '0', \
    \`ipcomp` tinyint(4) unsigned NOT NULL default '0', \
    \`reqid` mediumint(8) unsigned NOT NULL default '0', \
    \`mark` varchar(32) collate utf8_unicode_ci default NULL, \
    \PRIMARY KEY (`id`), INDEX (`name`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @child_configs@ row from thirteen parameters; @id@ is left
-- to the table's auto-increment.
createChildSAStatement :: SQL.Query
createChildSAStatement =
    "INSERT INTO child_configs \
    \(name, lifetime, rekeytime, jitter, updown, hostaccess, mode, \
    \start_action, dpd_action, close_action, ipcomp, reqid, mark) \
    \VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"

-- | Selects the @child_configs@ rows whose @name@ equals the parameter.
findChildSAByNameStatement :: SQL.Query
findChildSAByNameStatement =
    "SELECT * FROM child_configs WHERE name = ?;"

-- | Selects the @child_configs@ row with the given @id@.
findChildSAStatement :: SQL.Query
findChildSAStatement =
    "SELECT * FROM child_configs WHERE id = ?;"

-- | Deletes the @child_configs@ row with the given @id@.
deleteChildSAStatement :: SQL.Query
deleteChildSAStatement =
    "DELETE FROM child_configs WHERE id = ?;"

-- | DDL for the @peer_configs@ table, indexed on @name@. The literal is
-- split with string gaps for readability only.
createPeerTableStatement :: SQL.Query
createPeerTableStatement =
    "CREATE TABLE `peer_configs` ( \
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`name` varchar(32) collate utf8_unicode_ci NOT NULL, \
    \`ike_version` tinyint(3) unsigned NOT NULL default '2', \
    \`ike_cfg` int(10) unsigned NOT NULL, \
    \`local_id` varchar(64) collate utf8_unicode_ci NOT NULL, \
    \`remote_id` varchar(64) collate utf8_unicode_ci NOT NULL, \
    \`cert_policy` tinyint(3) unsigned NOT NULL default '1', \
    \`uniqueid` tinyint(3) unsigned NOT NULL default '0', \
    \`auth_method` tinyint(3) unsigned NOT NULL default '1', \
    \`eap_type` tinyint(3) unsigned NOT NULL default '0', \
    \`eap_vendor` smallint(5) unsigned NOT NULL default '0', \
    \`keyingtries` tinyint(3) unsigned NOT NULL default '3', \
    \`rekeytime` mediumint(8) unsigned NOT NULL default '7200', \
    \`reauthtime` mediumint(8) unsigned NOT NULL default '0', \
    \`jitter` mediumint(8) unsigned NOT NULL default '180', \
    \`overtime` mediumint(8) unsigned NOT NULL default '300', \
    \`mobike` tinyint(1) NOT NULL default '1', \
    \`dpd_delay` mediumint(8) unsigned NOT NULL default '120', \
    \`virtual` varchar(40) default NULL, \
    \`pool` varchar(32) default NULL, \
    \`mediation` tinyint(1) NOT NULL default '0', \
    \`mediated_by` int(10) unsigned NOT NULL default '0', \
    \`peer_id` int(10) unsigned NOT NULL default '0', \
    \PRIMARY KEY (`id`), INDEX (`name`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @peer_configs@ row, one parameter per listed column; @id@
-- is left to the table's auto-increment.
createPeerStatement :: SQL.Query
createPeerStatement =
    "INSERT INTO peer_configs (\
    \`name`, `ike_version`, `ike_cfg`, `local_id`, `remote_id`, \
    \`cert_policy`, `uniqueid`, `auth_method`, `eap_type`, `eap_vendor`, \
    \`keyingtries`, `rekeytime`, `reauthtime`, `jitter`, `overtime`, \
    \`mobike`, `dpd_delay`, `virtual`, `pool`, `mediation`, \
    \`mediated_by`, `peer_id` ) \
    \VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"

-- | Overwrites every data column of the @peer_configs@ row selected by @id@
-- (last parameter).
updatePeerStatement :: SQL.Query
updatePeerStatement =
    "UPDATE peer_configs SET \
    \`name` = ?, `ike_version` = ?, `ike_cfg` = ?, \
    \`local_id` = ?, `remote_id` = ?, `cert_policy` = ?, \
    \`uniqueid` = ?, `auth_method` = ?, `eap_type` = ?, \
    \`eap_vendor` = ?, `keyingtries` = ?, `rekeytime` = ?, \
    \`reauthtime` = ?, `jitter` = ?, `overtime` = ?, \
    \`mobike` = ?, `dpd_delay` = ?, `virtual` = ?, \
    \`pool` = ?, `mediation` = ?, `mediated_by` = ?, \
    \`peer_id`  = ? \
    \WHERE `id` = ?;"

-- | Selects the @peer_configs@ row with the given @id@.
findPeerStatement :: SQL.Query
findPeerStatement =
    "SELECT * from peer_configs WHERE id = ?;"

-- | Selects the @peer_configs@ rows whose @name@ equals the parameter.
findPeerByNameStatement :: SQL.Query
findPeerByNameStatement =
    "SELECT * from peer_configs WHERE name = ?;"

-- | Deletes the @peer_configs@ row with the given @id@.
deletePeerStatement :: SQL.Query
deletePeerStatement =
    "DELETE from peer_configs WHERE id = ?;"

-- | DDL for the @peer_config_child_config@ link table, keyed on the
-- (@peer_cfg@, @child_cfg@) pair.
createP2CTableStatement :: SQL.Query
createP2CTableStatement =
    "CREATE TABLE `peer_config_child_config` (\
    \`peer_cfg` int(10) unsigned NOT NULL, \
    \`child_cfg` int(10) unsigned NOT NULL, \
    \PRIMARY KEY (`peer_cfg`, `child_cfg`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a (@peer_cfg@, @child_cfg@) pair into the link table.
createP2CStatement :: SQL.Query
createP2CStatement =
    "INSERT INTO peer_config_child_config (peer_cfg, child_cfg) VALUES (?, ?);"

-- | Rewrites a link row: the first two parameters are the new pair, the last
-- two identify the existing (@peer_cfg@, @child_cfg@) row.
updateP2CStatement :: SQL.Query
updateP2CStatement =
    "UPDATE peer_config_child_config SET peer_cfg = ?, child_cfg = ? WHERE peer_cfg = ? AND child_cfg = ?;"

-- | Selects the link row matching both @peer_cfg@ and @child_cfg@.
findP2CStatement :: SQL.Query
findP2CStatement =
    "SELECT * from peer_config_child_config WHERE peer_cfg = ? AND child_cfg = ?;"

-- | Deletes the link row matching both @peer_cfg@ and @child_cfg@.
--
-- The query text is sent without a trailing @;@. A stray Haskell-level
-- semicolon that used to follow the literal has been dropped; the SQL bytes
-- are unchanged.
deleteP2CStatement :: SQL.Query
deleteP2CStatement =
    "DELETE from peer_config_child_config WHERE peer_cfg = ? AND child_cfg = ?"

-- | DDL for the @traffic_selectors@ table. The literal is split with string
-- gaps for readability only.
createTSTableStatement :: SQL.Query
createTSTableStatement =
    "CREATE TABLE `traffic_selectors` (\
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`type` tinyint(3) unsigned NOT NULL default '7', \
    \`protocol` smallint(5) unsigned NOT NULL default '0', \
    \`start_addr` varbinary(16) default NULL, \
    \`end_addr` varbinary(16) default NULL, \
    \`start_port` smallint(5) unsigned NOT NULL default '0', \
    \`end_port` smallint(5) unsigned NOT NULL default '65535', \
    \PRIMARY KEY (`id`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @traffic_selectors@ row from six parameters; @id@ is left
-- to the table's auto-increment.
createTSStatement :: SQL.Query
createTSStatement =
    "INSERT INTO traffic_selectors (type, protocol, start_addr, end_addr, start_port, end_port) VALUES ( ?, ?, ?, ?, ?, ?);"

-- | Overwrites the six data columns of the @traffic_selectors@ row selected
-- by @id@ (last parameter).
updateTSStatement :: SQL.Query
updateTSStatement =
    "UPDATE traffic_selectors SET type = ?, protocol = ?, start_addr = ?, end_addr = ?, start_port = ?, end_port = ? WHERE id = ?;"

-- | Selects the @traffic_selectors@ row with the given @id@.
findTSStatement :: SQL.Query
findTSStatement =
    "SELECT * FROM traffic_selectors WHERE id = ?;"

-- | Deletes the @traffic_selectors@ row with the given @id@.
--
-- The query text is sent without a trailing @;@. A stray Haskell-level
-- semicolon that used to follow the literal has been dropped; the SQL bytes
-- are unchanged.
deleteTSStatement :: SQL.Query
deleteTSStatement =
    "DELETE FROM traffic_selectors WHERE id = ?"

-- | DDL for the @child_config_traffic_selector@ link table. It declares an
-- index on (@child_cfg@, @traffic_selector@) but no primary key.
createC2TSTableStatement :: SQL.Query
createC2TSTableStatement =
    "CREATE TABLE `child_config_traffic_selector` (\
    \`child_cfg` int(10) unsigned NOT NULL, \
    \`traffic_selector` int(10) unsigned NOT NULL, \
    \`kind` tinyint(3) unsigned NOT NULL, \
    \INDEX (`child_cfg`, `traffic_selector`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a (@child_cfg@, @traffic_selector@, @kind@) link row.
createC2TSStatement :: SQL.Query
createC2TSStatement =
    "INSERT INTO child_config_traffic_selector (child_cfg, traffic_selector, kind) VALUES (?, ?, ?);"

-- | Rewrites link rows: the first three parameters are the new values, the
-- last two select rows by @child_cfg@ and @traffic_selector@.
updateC2TSStatement :: SQL.Query
updateC2TSStatement =
    "UPDATE child_config_traffic_selector SET child_cfg = ?, traffic_selector = ?, kind = ? WHERE child_cfg = ? AND traffic_selector = ?;"

-- | Selects every link row for the given @child_cfg@.
findC2TSStatement :: SQL.Query
findC2TSStatement =
    "SELECT * FROM child_config_traffic_selector WHERE child_cfg = ?;"

-- | Deletes every link row for the given @child_cfg@.
deleteC2TSStatement :: SQL.Query
deleteC2TSStatement =
    "DELETE FROM child_config_traffic_selector WHERE child_cfg = ?;"

-- | DDL for the @shared_secrets@ table.
createSharedSecretTable :: SQL.Query
createSharedSecretTable =
    "CREATE TABLE shared_secrets (\
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`type` tinyint(3) unsigned NOT NULL, \
    \`data` varbinary(256) NOT NULL, \
    \PRIMARY KEY (`id`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @shared_secrets@ row from (@type@, @data@); @id@ is left
-- to the table's auto-increment.
createSharedSecretStatement :: SQL.Query
createSharedSecretStatement =
    "INSERT INTO shared_secrets (type, data) VALUES (?, ?);"

-- | Overwrites @type@ and @data@ of the @shared_secrets@ row selected by
-- @id@ (last parameter).
updateSharedSecretStatement :: SQL.Query
updateSharedSecretStatement =
    "UPDATE shared_secrets SET type = ?, data = ? WHERE id = ?;"

-- | Selects the @shared_secrets@ row with the given @id@.
findSharedSecretStatement :: SQL.Query
findSharedSecretStatement =
    "SELECT * FROM shared_secrets WHERE id = ?;"

-- | Deletes the @shared_secrets@ row with the given @id@.
deleteSharedSecretStatement :: SQL.Query
deleteSharedSecretStatement =
    "DELETE FROM shared_secrets WHERE id = ?;"

-- | DDL for the @identities@ table; the (@type@, @data@) pair is declared
-- unique.
createIdentityTable :: SQL.Query
createIdentityTable =
    "CREATE TABLE `identities` (\
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`type` tinyint(4) unsigned NOT NULL, \
    \`data` varbinary(64) NOT NULL, \
    \PRIMARY KEY (`id`), UNIQUE (`type`, `data`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @identities@ row from (@type@, @data@); @id@ is left to
-- the table's auto-increment.
createIdentityStatement :: SQL.Query
createIdentityStatement =
    "INSERT INTO identities (type, data) VALUES (?, ?);"

-- | Overwrites @type@ and @data@ of the @identities@ row selected by @id@
-- (last parameter).
updateIdentityStatement :: SQL.Query
updateIdentityStatement =
    "UPDATE identities SET type = ?, data = ? WHERE id = ?;"

-- | Selects the @identities@ row with the given @id@.
findIdentityStatement :: SQL.Query
findIdentityStatement =
    "SELECT * FROM identities WHERE id = ?;"

-- | Looks an identity up by its own contents: selects the @identities@ row
-- matching both @type@ and @data@.
findIdentityBySelfStatement :: SQL.Query
findIdentityBySelfStatement =
    "SELECT * FROM identities WHERE type = ? AND data = ?;"

-- | Deletes the @identities@ row with the given @id@.
deleteIdentityStatement :: SQL.Query
deleteIdentityStatement =
    "DELETE FROM identities WHERE id = ?;"

-- | DDL for the @shared_secret_identity@ link table, keyed on the
-- (@shared_secret@, @identity@) pair.
createSSIdentityTable :: SQL.Query
createSSIdentityTable =
    "CREATE TABLE shared_secret_identity (\
    \`shared_secret` int(10) unsigned NOT NULL, \
    \`identity` int(10) unsigned NOT NULL, \
    \PRIMARY KEY (`shared_secret`, `identity`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a (@shared_secret@, @identity@) pair into the link table.
createSSIdentityStatement :: SQL.Query
createSSIdentityStatement =
    "INSERT INTO shared_secret_identity (shared_secret, identity) VALUES (?, ?);"

-- | Rewrites a link row: the first two parameters are the new pair, the last
-- two identify the existing (@shared_secret@, @identity@) row.
updateSSIdentityStatement :: SQL.Query
updateSSIdentityStatement =
    "UPDATE shared_secret_identity SET shared_secret = ?, identity = ? WHERE shared_secret = ? AND identity = ?;"

-- | Selects every link row for the given @identity@.
findSSIdentityStatement :: SQL.Query
findSSIdentityStatement =
    "SELECT * from shared_secret_identity WHERE identity = ?;"

-- | Deletes the link row matching both @shared_secret@ and @identity@.
deleteSSIdentityStatement :: SQL.Query
deleteSSIdentityStatement =
    "DELETE FROM shared_secret_identity WHERE shared_secret = ? AND identity = ?;"

-- | DDL for the @ipsec_configs@ table, keyed on @name@; every other column
-- is an unsigned integer.
createIPSecTableStatement :: SQL.Query
createIPSecTableStatement =
    "CREATE TABLE `ipsec_configs` (\
    \`name` varchar(64) NOT NULL, \
    \`child_cfg` int(10) unsigned NOT NULL, \
    \`peer_cfg` int(10) unsigned NOT NULL, \
    \`ike_cfg` int(10) unsigned NOT NULL, \
    \`local_ts` int(10) unsigned NOT NULL, \
    \`remote_ts` int(10) unsigned NOT NULL, \
    \`local_id` int(10) unsigned NOT NULL, \
    \`remote_id` int(10) unsigned NOT NULL, \
    \PRIMARY KEY (`name`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Inserts a new @ipsec_configs@ row from eight parameters, @name@ first.
createIPSecStatement :: SQL.Query
createIPSecStatement =
    "INSERT INTO ipsec_configs (name, child_cfg, peer_cfg, ike_cfg, local_ts, remote_ts, local_id, remote_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?);"

-- | Selects the @ipsec_configs@ row with the given @name@.
findIPSecStatement :: SQL.Query
findIPSecStatement =
    "SELECT * FROM ipsec_configs WHERE name = ?;"

-- | Deletes the @ipsec_configs@ row with the given @name@.
deleteIPSecStatement :: SQL.Query
deleteIPSecStatement =
    "DELETE FROM ipsec_configs WHERE name = ?;"

-- | DDL for the @proposals@ table.
createProposalsTableStatement :: SQL.Query
createProposalsTableStatement =
    "CREATE TABLE `proposals` (\
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`proposal` varchar(128) NOT NULL, \
    \PRIMARY KEY (`id`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | DDL for the @certificates@ table. The double space before
-- @PRIMARY KEY@ is part of the original query text and is preserved.
createCertificatesTableStatement :: SQL.Query
createCertificatesTableStatement =
    "CREATE TABLE certificates (\
    \`id` int(10) unsigned NOT NULL auto_increment, \
    \`type` tinyint(3) unsigned NOT NULL, \
    \`keytype` tinyint(3) unsigned NOT NULL, \
    \`data` BLOB NOT NULL,  \
    \PRIMARY KEY (`id`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"


-- | Make sure a table exists, then prepare the statements that operate on it.
--
-- The @CREATE TABLE@ query is prepared and executed on the implicit @?conn@
-- connection. A MySQL error with code 1050 (table already exists) is
-- ignored; any other failure is re-raised through 'failure'. Afterwards the
-- statements in the second argument are prepared in order and their handles
-- are returned in that same order.
initializeWith :: (Failable m, MonadIO m, ?conn::SQL.MySQLConn)
                => SQL.Query      -- ^ @CREATE TABLE@ statement for the table
                -> [SQL.Query]    -- ^ statements to prepare once the table exists
                -> m [SQL.StmtID]
initializeWith createTableStatement operStatements = do
    createTable  <- prepare createTableStatement
    -- create config table before any other operations on it. If it already exists.. whatever
    (void . failableIO $ SQL.executeStmt ?conn createTable [])
                            `recover` \e ->
                                case fromException e of
                                  Just (SQL.ERRException SQL.ERR{..}) | errCode == 1050 ->
                                    -- ignore table already exists errors
                                    return ()
                                  _ ->
                                    failure e

    -- The DDL handle is never used again. Release it so that every start-up
    -- does not leave one extra prepared statement per table allocated on the
    -- server for the lifetime of the connection.
    failableIO $ SQL.closeStmt ?conn createTable

    mapM prepare operStatements