{-# LANGUAGE OverloadedStrings #-}

module StrongSwan.SQL.Statements where

import Control.Exception         (fromException)
import Control.Monad             (void)
import Control.Monad.IO.Class    (MonadIO)
import Control.Monad.Fail        (MonadFail)
import Control.Monad.Failable
import Database.MySQL.Base       (MySQLConn)
import StrongSwan.SQL.Types

import qualified Database.MySQL.Base as SQL

-- | Create (if missing) every configuration table on the given connection
-- and prepare all statements used to manipulate them, bundled into a
-- 'PreparedStatements' record.
--
-- Tables are initialised in this order: @child_configs@, @ike_configs@,
-- @peer_configs@, @peer_config_child_config@, @traffic_selectors@,
-- @child_config_traffic_selector@ and @ipsec_configs@. Driver errors (other
-- than the table-already-exists error, which 'initializeWith' ignores) are
-- raised through 'Failable'.
--
-- NOTE: each list pattern below binds statement ids by position, so the
-- names on the left must stay in the same order as the statements passed
-- to 'initializeWith'. 'initializeWith' returns exactly one id per input
-- statement (it uses 'mapM'), so the patterns always match and the
-- 'MonadFail' path is never taken in practice.
prepareStatements :: (Failable m, MonadIO m, MonadFail m) => SQL.MySQLConn -> m PreparedStatements
prepareStatements conn = do
    -- Bind the implicit parameter consumed by 'prepare' and 'initializeWith'.
    let ?conn = conn
    -- child_configs
    [createChildSA, updateChildSA, findChildSA, deleteChildSA] <-
        initializeWith createChildSATableStatement
                       [createChildSAStatement,
                        updateChildSAStatement,
                        findChildSAStatement,
                        deleteChildSAStatement]

    -- Extra lookup on child_configs; the table already exists at this point.
    findChildSAByName <- prepare findChildSAByNameStatement

    -- ike_configs
    [createIKE, updateIKE, findIKE, deleteIKE] <-
        initializeWith createIKETableStatement
                       [createIKEStatement,
                        updateIKEStatement,
                        findIKEStatement,
                        deleteIKEStatement]

    -- peer_configs
    [createPeer, updatePeer, findPeer, deletePeer] <-
        initializeWith createPeerTableStatement
                       [createPeerStatement,
                        updatePeerStatement,
                        findPeerStatement,
                        deletePeerStatement]

    -- Extra lookup on peer_configs; the table already exists at this point.
    findPeerByName <- prepare findPeerByNameStatement

    -- peer_config_child_config (peer <-> child link table)
    [createP2C, updateP2C, findP2C, deleteP2C] <-
        initializeWith createP2CTableStatement
                       [createP2CStatement,
                        updateP2CStatement,
                        findP2CStatement,
                        deleteP2CStatement]

    -- traffic_selectors
    [createTS, updateTS, findTS, deleteTS] <-
        initializeWith createTSTableStatement
                       [createTSStatement,
                        updateTSStatement,
                        findTSStatement,
                        deleteTSStatement]

    -- child_config_traffic_selector (child <-> traffic selector link table)
    [createC2TS, updateC2TS, findC2TS, deleteC2TS] <-
        initializeWith createC2TSTableStatement
                       [createC2TSStatement,
                        updateC2TSStatement,
                        findC2TSStatement,
                        deleteC2TSStatement]

    -- ipsec_configs (no update statement is prepared for this table)
    [createIPSec, findIPSec, deleteIPSec] <-
        initializeWith createIPSecTableStatement
                       [createIPSecStatement,
                        findIPSecStatement,
                        deleteIPSecStatement]

    return PreparedStatements  {
        updateChildSAStmt     = updateChildSA,
        createChildSAStmt     = createChildSA,
        findChildSAByNameStmt = findChildSAByName,
        findChildSAStmt       = findChildSA,
        deleteChildSAStmt     = deleteChildSA,
        updateIKEStmt         = updateIKE,
        createIKEStmt         = createIKE,
        findIKEStmt           = findIKE,
        deleteIKEStmt         = deleteIKE,
        updatePeerStmt        = updatePeer,
        createPeerStmt        = createPeer,
        findPeerStmt          = findPeer,
        findPeerByNameStmt    = findPeerByName,
        deletePeerStmt        = deletePeer,
        updateP2CStmt         = updateP2C,
        createP2CStmt         = createP2C,
        findP2CStmt           = findP2C,
        deleteP2CStmt         = deleteP2C,
        findTSStmt            = findTS,
        createTSStmt          = createTS,
        updateTSStmt          = updateTS,
        deleteTSStmt          = deleteTS,
        updateC2TSStmt        = updateC2TS,
        createC2TSStmt        = createC2TS,
        findC2TSStmt          = findC2TS,
        deleteC2TSStmt        = deleteC2TS,
        createIPSecStmt       = createIPSec,
        findIPSecStmt         = findIPSec,
        deleteIPSecStmt       = deleteIPSec
    }

-- | Prepare @query@ on the connection bound to the implicit @?conn@
-- parameter and return its server-side statement id. The driver call is
-- run through 'failableIO', so driver exceptions surface via 'Failable'.
prepare :: (?conn :: MySQLConn, Failable m, MonadIO m) => SQL.Query -> m SQL.StmtID
prepare query = failableIO (SQL.prepareStmt ?conn query)

-- | Overwrite every mutable column of an @ike_configs@ row.
--
-- Bind order: @certreq@, @force_encap@, @local@, @remote@, then the row @id@.
updateIKEStatement :: SQL.Query
updateIKEStatement = "UPDATE ike_configs SET certreq = ?, force_encap = ?, local = ?, remote = ? WHERE id = ?;"

-- | DDL for @ike_configs@: an auto-increment @id@ primary key, the
-- @certreq@ (default 1) and @force_encap@ (default 0) flags, and the
-- @local@ / @remote@ varchar(128) columns.
createIKETableStatement :: SQL.Query
createIKETableStatement = "CREATE TABLE `ike_configs` ( `id` int(10) unsigned NOT NULL auto_increment, `certreq` tinyint(3) unsigned NOT NULL default '1', `force_encap` tinyint(1) NOT NULL default '0', `local` varchar(128) collate utf8_unicode_ci NOT NULL, `remote` varchar(128) collate utf8_unicode_ci NOT NULL, PRIMARY KEY (`id`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Insert an @ike_configs@ row; @id@ is assigned by auto_increment.
--
-- Bind order: @certreq@, @force_encap@, @local@, @remote@.
createIKEStatement :: SQL.Query
createIKEStatement = "INSERT INTO ike_configs (certreq, force_encap, local, remote) VALUES (?, ?, ?, ?);"

-- | Fetch all columns of the @ike_configs@ row with the given @id@.
findIKEStatement :: SQL.Query
findIKEStatement = "SELECT * FROM ike_configs WHERE id = ?;"

-- | Delete the @ike_configs@ row with the given @id@.
deleteIKEStatement :: SQL.Query
deleteIKEStatement = "DELETE FROM ike_configs WHERE id = ?;"

-- | Overwrite every mutable column of a @child_configs@ row.
--
-- Bind order: @name@, @lifetime@, @rekeytime@, @jitter@, @updown@,
-- @hostaccess@, @mode@, @start_action@, @dpd_action@, @close_action@,
-- @ipcomp@, @reqid@, then the row @id@.
updateChildSAStatement :: SQL.Query
updateChildSAStatement = "UPDATE child_configs SET name = ?, lifetime = ?, rekeytime = ?, jitter = ?, updown = ?, hostaccess = ?, mode = ?, start_action = ?, dpd_action = ?, close_action = ?, ipcomp = ?, reqid = ? WHERE id = ?"

-- | DDL for @child_configs@: an auto-increment @id@ primary key and a
-- (non-unique) index on @name@, which backs 'findChildSAByNameStatement'.
-- @updown@ is the only nullable column.
createChildSATableStatement :: SQL.Query
createChildSATableStatement = "CREATE TABLE `child_configs` ( `id` int(10) unsigned NOT NULL auto_increment, `name` varchar(32) collate utf8_unicode_ci NOT NULL, `lifetime` mediumint(8) unsigned NOT NULL default '1500', `rekeytime` mediumint(8) unsigned NOT NULL default '1200', `jitter` mediumint(8) unsigned NOT NULL default '60', `updown` varchar(128) collate utf8_unicode_ci default NULL, `hostaccess` tinyint(1) unsigned NOT NULL default '0', `mode` tinyint(4) unsigned NOT NULL default '2', `start_action` tinyint(4) unsigned NOT NULL default '0', `dpd_action` tinyint(4) unsigned NOT NULL default '0', `close_action` tinyint(4) unsigned NOT NULL default '0', `ipcomp` tinyint(4) unsigned NOT NULL default '0', `reqid` mediumint(8) unsigned NOT NULL default '0', PRIMARY KEY (`id`), INDEX (`name`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Insert a @child_configs@ row; @id@ is assigned by auto_increment.
--
-- Bind order matches the column list: the same as
-- 'updateChildSAStatement' without the trailing @id@.
createChildSAStatement :: SQL.Query
createChildSAStatement = "INSERT INTO child_configs (name, lifetime, rekeytime, jitter, updown, hostaccess, mode, start_action, dpd_action, close_action, ipcomp, reqid) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"

-- | Fetch all columns of the @child_configs@ rows with the given @name@.
-- @name@ is indexed but not declared unique, so several rows may match.
findChildSAByNameStatement :: SQL.Query
findChildSAByNameStatement = "SELECT * FROM child_configs WHERE name = ?;"

-- | Fetch all columns of the @child_configs@ row with the given @id@.
findChildSAStatement :: SQL.Query
findChildSAStatement = "SELECT * FROM child_configs WHERE id = ?;"

-- | Delete the @child_configs@ row with the given @id@.
deleteChildSAStatement :: SQL.Query
deleteChildSAStatement = "DELETE FROM child_configs WHERE id = ?;"

-- | DDL for @peer_configs@: an auto-increment @id@ primary key and a
-- (non-unique) index on @name@, which backs 'findPeerByNameStatement'.
-- @virtual@ and @pool@ are the only nullable columns. @ike_cfg@ presumably
-- holds an @ike_configs.id@ -- NOTE(review): no foreign key is declared,
-- so the database does not enforce this.
createPeerTableStatement :: SQL.Query
createPeerTableStatement = "CREATE TABLE `peer_configs` ( `id` int(10) unsigned NOT NULL auto_increment, `name` varchar(32) collate utf8_unicode_ci NOT NULL, `ike_version` tinyint(3) unsigned NOT NULL default '2', `ike_cfg` int(10) unsigned NOT NULL, `local_id` varchar(64) collate utf8_unicode_ci NOT NULL, `remote_id` varchar(64) collate utf8_unicode_ci NOT NULL, `cert_policy` tinyint(3) unsigned NOT NULL default '1', `uniqueid` tinyint(3) unsigned NOT NULL default '0', `auth_method` tinyint(3) unsigned NOT NULL default '1', `eap_type` tinyint(3) unsigned NOT NULL default '0', `eap_vendor` smallint(5) unsigned NOT NULL default '0', `keyingtries` tinyint(3) unsigned NOT NULL default '3', `rekeytime` mediumint(8) unsigned NOT NULL default '7200', `reauthtime` mediumint(8) unsigned NOT NULL default '0', `jitter` mediumint(8) unsigned NOT NULL default '180', `overtime` mediumint(8) unsigned NOT NULL default '300', `mobike` tinyint(1) NOT NULL default '1', `dpd_delay` mediumint(8) unsigned NOT NULL default '120', `virtual` varchar(40) default NULL, `pool` varchar(32) default NULL, `mediation` tinyint(1) NOT NULL default '0', `mediated_by` int(10) unsigned NOT NULL default '0', `peer_id` int(10) unsigned NOT NULL default '0', PRIMARY KEY (`id`), INDEX (`name`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Insert a @peer_configs@ row; @id@ is assigned by auto_increment.
--
-- Bind order is the 22-column list in the statement, from @name@ through
-- @peer_id@.
createPeerStatement :: SQL.Query
createPeerStatement = "INSERT INTO peer_configs (`name`, `ike_version`, `ike_cfg`, `local_id`, `remote_id`, `cert_policy`, `uniqueid`, `auth_method`, `eap_type`, `eap_vendor`, `keyingtries`, `rekeytime`, `reauthtime`, `jitter`, `overtime`, `mobike`, `dpd_delay`, `virtual`, `pool`, `mediation`, `mediated_by`, `peer_id` ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"

-- | Overwrite every mutable column of a @peer_configs@ row.
--
-- Bind order: the same 22 columns as 'createPeerStatement', then the row
-- @id@ (23 parameters in total).
updatePeerStatement :: SQL.Query
updatePeerStatement = "UPDATE peer_configs SET `name` = ?, `ike_version` = ?, `ike_cfg` = ?, `local_id` = ?, `remote_id` = ?, `cert_policy` = ?, `uniqueid` = ?, `auth_method` = ?, `eap_type` = ?, `eap_vendor` = ?, `keyingtries` = ?, `rekeytime` = ?, `reauthtime` = ?, `jitter` = ?, `overtime` = ?, `mobike` = ?, `dpd_delay` = ?, `virtual` = ?, `pool` = ?, `mediation` = ?, `mediated_by` = ?, `peer_id`  = ? WHERE `id` = ?;"

-- | Fetch all columns of the @peer_configs@ row with the given @id@.
findPeerStatement :: SQL.Query
findPeerStatement = "SELECT * from peer_configs WHERE id = ?;"

-- | Fetch all columns of the @peer_configs@ rows with the given @name@.
-- @name@ is indexed but not declared unique, so several rows may match.
findPeerByNameStatement :: SQL.Query
findPeerByNameStatement = "SELECT * from peer_configs WHERE name = ?;"

-- | Delete the @peer_configs@ row with the given @id@.
deletePeerStatement :: SQL.Query
deletePeerStatement = "DELETE from peer_configs WHERE id = ?;"

-- | DDL for @peer_config_child_config@, the link table between peer configs
-- and child configs. The (@peer_cfg@, @child_cfg@) pair is the primary key,
-- so a given link can exist only once.
createP2CTableStatement :: SQL.Query
createP2CTableStatement = "CREATE TABLE `peer_config_child_config` (`peer_cfg` int(10) unsigned NOT NULL, `child_cfg` int(10) unsigned NOT NULL, PRIMARY KEY (`peer_cfg`, `child_cfg`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Link a peer config to a child config.
--
-- Bind order: @peer_cfg@, @child_cfg@.
createP2CStatement :: SQL.Query
createP2CStatement = "INSERT INTO peer_config_child_config (peer_cfg, child_cfg) VALUES (?, ?);"

-- | Re-point an existing link.
--
-- Bind order: new @peer_cfg@, new @child_cfg@, then the current @peer_cfg@
-- and @child_cfg@ that identify the row to change.
updateP2CStatement :: SQL.Query
updateP2CStatement = "UPDATE peer_config_child_config SET peer_cfg = ?, child_cfg = ? WHERE peer_cfg = ? AND child_cfg = ?;"

-- | Fetch the link row for a (@peer_cfg@, @child_cfg@) pair.
findP2CStatement :: SQL.Query
findP2CStatement = "SELECT * from peer_config_child_config WHERE peer_cfg = ? AND child_cfg = ?;"

-- | Remove the link row for a (@peer_cfg@, @child_cfg@) pair.
deleteP2CStatement :: SQL.Query
deleteP2CStatement = "DELETE from peer_config_child_config WHERE peer_cfg = ? AND child_cfg = ?;"

-- | DDL for @traffic_selectors@: an auto-increment @id@ primary key, a
-- @type@ (default 7) and @protocol@ (default 0), nullable @varbinary(16)@
-- start/end addresses, and a port range defaulting to 0..65535.
createTSTableStatement :: SQL.Query
createTSTableStatement = "CREATE TABLE `traffic_selectors` (`id` int(10) unsigned NOT NULL auto_increment, `type` tinyint(3) unsigned NOT NULL default '7', `protocol` smallint(5) unsigned NOT NULL default '0', `start_addr` varbinary(16) default NULL, `end_addr` varbinary(16) default NULL, `start_port` smallint(5) unsigned NOT NULL default '0', `end_port` smallint(5) unsigned NOT NULL default '65535', PRIMARY KEY (`id`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- | Insert a @traffic_selectors@ row; @id@ is assigned by auto_increment.
--
-- Bind order: @type@, @protocol@, @start_addr@, @end_addr@, @start_port@,
-- @end_port@.
createTSStatement :: SQL.Query
createTSStatement = "INSERT INTO traffic_selectors (type, protocol, start_addr, end_addr, start_port, end_port) VALUES ( ?, ?, ?, ?, ?, ?);"

-- | Overwrite every mutable column of a @traffic_selectors@ row.
--
-- Bind order: the same six columns as 'createTSStatement', then the row @id@.
updateTSStatement :: SQL.Query
updateTSStatement = "UPDATE traffic_selectors SET type = ?, protocol = ?, start_addr = ?, end_addr = ?, start_port = ?, end_port = ? WHERE id = ?;"

-- | Fetch all columns of the @traffic_selectors@ row with the given @id@.
findTSStatement :: SQL.Query
findTSStatement = "SELECT * FROM traffic_selectors WHERE id = ?;"

-- | Delete the @traffic_selectors@ row with the given @id@.
deleteTSStatement :: SQL.Query
deleteTSStatement = "DELETE FROM traffic_selectors WHERE id = ?;"

-- | Statements for @child_config_traffic_selector@, which links child
-- configs to traffic selectors and tags each link with a @kind@. The table
-- has no primary key, only a plain (non-unique) index on (@child_cfg@,
-- @traffic_selector@), so duplicate links are not rejected.
createC2TSTableStatement, createC2TSStatement, updateC2TSStatement,
  findC2TSStatement, deleteC2TSStatement :: SQL.Query

-- Table DDL, split with string gaps for readability (same text as one line).
createC2TSTableStatement =
    "CREATE TABLE `child_config_traffic_selector` (\
    \`child_cfg` int(10) unsigned NOT NULL, \
    \`traffic_selector` int(10) unsigned NOT NULL, \
    \`kind` tinyint(3) unsigned NOT NULL, \
    \INDEX (`child_cfg`, `traffic_selector`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- Bind order: child_cfg, traffic_selector, kind.
createC2TSStatement =
    "INSERT INTO child_config_traffic_selector (child_cfg, traffic_selector, kind) \
    \VALUES (?, ?, ?);"

-- Bind order: new child_cfg, new traffic_selector, new kind, then the
-- current child_cfg and traffic_selector selecting the row(s) to change.
-- Without a unique key, more than one row can match.
updateC2TSStatement =
    "UPDATE child_config_traffic_selector \
    \SET child_cfg = ?, traffic_selector = ?, kind = ? \
    \WHERE child_cfg = ? AND traffic_selector = ?;"

-- Every link (of any kind) belonging to one child config.
findC2TSStatement = "SELECT * FROM child_config_traffic_selector WHERE child_cfg = ?;"

-- Removes every link (of any kind) belonging to one child config.
deleteC2TSStatement = "DELETE FROM child_config_traffic_selector WHERE child_cfg = ?;"

-- | Statements for @ipsec_configs@, keyed by a @name@ (the primary key) and
-- storing the @child_cfg@, @peer_cfg@, @ike_cfg@, @local_ts@ and
-- @remote_ts@ ids that entry refers to. No foreign keys are declared, so
-- the referenced rows are not checked by the database.
createIPSecTableStatement, createIPSecStatement,
  findIPSecStatement, deleteIPSecStatement :: SQL.Query

-- Table DDL, split with string gaps for readability (same text as one line).
createIPSecTableStatement =
    "CREATE TABLE `ipsec_configs` (\
    \`name` varchar(64) NOT NULL, \
    \`child_cfg` int(10) unsigned NOT NULL, \
    \`peer_cfg` int(10) unsigned NOT NULL, \
    \`ike_cfg` int(10) unsigned NOT NULL, \
    \`local_ts` int(10) unsigned NOT NULL, \
    \`remote_ts` int(10) unsigned NOT NULL, \
    \PRIMARY KEY (`name`)) \
    \ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;"

-- Bind order: name, child_cfg, peer_cfg, ike_cfg, local_ts, remote_ts.
createIPSecStatement =
    "INSERT INTO ipsec_configs (name, child_cfg, peer_cfg, ike_cfg, local_ts, remote_ts) \
    \VALUES (?, ?, ?, ?, ?, ?);"

-- Looks up the single row with the given name (name is the primary key).
findIPSecStatement = "SELECT * FROM ipsec_configs WHERE name = ?;"

-- Deletes the row with the given name.
deleteIPSecStatement = "DELETE FROM ipsec_configs WHERE name = ?;"


-- | Make sure a table exists, then prepare the statements that operate on it.
--
-- The @CREATE TABLE@ statement is prepared and executed once. MySQL error
-- 1050 (table already exists) is treated as success; any other error is
-- re-raised through 'failure'. The single-use @CREATE TABLE@ handle is then
-- closed so it does not stay allocated on the server for the lifetime of
-- the connection.
--
-- Returns the statement ids for @operStatements@, one per input and in the
-- same order (callers destructure the result positionally).
initializeWith :: (Failable m, MonadIO m, ?conn::SQL.MySQLConn)
                => SQL.Query
                -> [SQL.Query]
                -> m [SQL.StmtID]
initializeWith createTableStatement operStatements = do
    createTable <- prepare createTableStatement
    -- Create the table before preparing any statement that refers to it.
    (void . failableIO $ SQL.executeStmt ?conn createTable [])
        `recover` \e ->
            case fromException e of
              Just (SQL.ERRException SQL.ERR{..}) | errCode == 1050 ->
                -- ER_TABLE_EXISTS_ERROR: the table is already there.
                return ()
              _ ->
                failure e
    -- Release the server-side handle; it is never used again.
    failableIO $ SQL.closeStmt ?conn createTable
    mapM prepare operStatements