{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, TypeSynonymInstances #-}

-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Access to the request database.
-- This module provides transparent access to the information stored in the
-- database, and ensures that it is read in a sensible way.
--
-----------------------------------------------------------------------------

module Database.Data
   ( SqlRecord
   , SqlTable(..)
   , requestsTable
   , studentsTable
   , (!)
   , findWithDefault
   , insert
   , adjust
   , quickQueryMap
   , tableExists
   , createTable
   , dropTable
   , insertRecord
   , insertRecords
   , SQL.commit
   , PilotID
   , StudentID
   , TaskID
   , NodeID
   , InputID
   , StudentModel
   , collectSource
   , collectInputs
   , collectIDs
   , collectSolution
   , defaultEntry
   , recordN
   , allEvidence
   , allEvidenceStudents
   , allPartialModels
   , allStudentModels
   , allHumanAssessments
   , tasksDone
   , latestRecord
   , latestEvidence
   , latestPartialModels
   , latestStudentModel
   , latestStudentReport
   , latestInsertRowID
   , records2table
   , allRecords
   , finalRecords
   , taskRecords
   , studentRecords
   , countStudents
   , countEntries
   , countStudentEntries
   , countTaskEntries ) where

import Data.Convertible.Base (Convertible, safeConvert, convError)
import Control.Monad ( when, void, (>=>) )
import Database.HDBC (SqlValue, fromSql, toSql)
import Database.HDBC.Sqlite3 (Connection)
import Ideas.Utils.Prelude (readM)
import qualified Data.Map as M
import qualified Ideas.Text.XML as XML
import qualified Database.HDBC as SQL

import Util.String (wrap)
import Bayes.Evidence ( Evidence )
import Bayes.StudentReport (StudentReport, toReport)
import Recognize.Data.Solution ( Solution )

-------------------------------------------------------------------------------
-- * Data conversion


-- | Decode an XML document stored in a database column. The column is first
-- read as a 'String'; if 'XML.parseXML' rejects it, its message is reported
-- through 'convError'.
instance Convertible SqlValue XML.XML where
   safeConvert sql =
      case XML.parseXML (fromSql sql) of
         Left msg  -> convError msg sql
         Right doc -> Right doc

-- | Store an XML document in a database column as its compact textual form
-- (see 'XML.compactXML'). This conversion never fails.
instance Convertible XML.XML SqlValue where
   safeConvert = Right . toSql . XML.compactXML

-- | Read 'Evidence' from a database column by parsing the column's text with
-- 'readM'. Text that does not parse yields empty evidence ('mempty') instead
-- of a conversion error.
instance Convertible SqlValue Evidence where
   safeConvert sql =
      case readM (fromSql sql) of
         Just evidence -> Right evidence
         Nothing       -> Right mempty

-- | Store 'Evidence' as its 'show'n text; the reverse instance parses that
-- text back with 'readM'. This conversion never fails.
instance Convertible Evidence SqlValue where
   safeConvert evidence = Right (toSql (show evidence))

-- | Extract the value of the @source@ attribute of an XML request element.
-- What happens when the attribute is missing is decided by
-- 'XML.findAttribute' in the monad @m@.
collectSource :: Monad m => XML.XML -> m String
collectSource = XML.findAttribute "source"

-- | Extract the @input@ children of the request's @solution@ element, each
-- paired with the value of its @id@ attribute. Fails (in @m@) when there is
-- no @solution@ child or an @input@ lacks an @id@.
collectInputs :: Monad m => XML.XML -> m [(InputID, XML.XML)]
collectInputs xml = do
   solution <- XML.findChild "solution" xml
   let inputs = XML.findChildren "input" solution
   ids <- traverse (XML.findAttribute "id") inputs
   return (zip ids inputs)


-- | Extract the student ID (the @userid@ attribute of the @solution@ child)
-- and the task ID (the @exerciseid@ attribute of the element itself) from an
-- XML request element. The task ID is looked up first.
collectIDs :: Monad m => XML.XML -> m (StudentID, TaskID)
collectIDs xml =
   (\task student -> (student, task))
      <$> XML.findAttribute "exerciseid" xml
      <*> (XML.findChild "solution" xml >>= XML.findAttribute "userid")


-- | Decode the @solution@ child of an XML request element as a 'Solution',
-- using its 'XML.fromXML' instance.
collectSolution :: Monad m => XML.XML -> m Solution
collectSolution xml = XML.findChild "solution" xml >>= XML.fromXML

-------------------------------------------------------------------------------
-- * Generic SQL

-- | A single database row, mapping each column name to its raw 'SqlValue'.
-- Query results are read into this shape by 'quickQueryMap'.
type SqlRecord = M.Map String SqlValue

-- | An 'SqlTable' collects the information needed to construct SQL statements
-- ('createTable', 'dropTable', 'insertRecord') for one table.
data SqlTable = SqlTable
   { tableName :: String
     -- ^ Name of the table; it is spliced verbatim into generated SQL.
   , columns :: [(String, String)]
     -- ^ Column definitions, in order: (column name, SQL type such as
     -- @"TEXT"@).
   }


-- | Obtain the value of a record at a certain column, converted from its
-- 'SqlValue' with 'fromSql'. Like @M.!@ this is partial, but when the column
-- is absent the error message names the missing column instead of reporting
-- only a generic missing-key failure.
(!) :: Convertible SqlValue a => SqlRecord -> String -> a
record ! key =
   case M.lookup key record of
      Just value -> fromSql value
      Nothing    -> error $ "Database.Data.(!): no column `" ++ key ++ "` in record"

-- | Obtain the value of a record at a certain column, converted with
-- 'fromSql', or the given default when the column is absent from the record.
-- Unlike @M.findWithDefault@, a present value is converted from 'SqlValue'.
findWithDefault :: Convertible SqlValue a => a -> String -> SqlRecord -> a
findWithDefault def key record =
   case M.lookup key record of
      Nothing    -> def
      Just value -> fromSql value


-- | Set a column of a record to a value, converted with 'toSql'. Any value
-- already stored in that column is replaced.
insert :: Convertible a SqlValue => String -> a -> SqlRecord -> SqlRecord
insert key = M.insert key . toSql


-- | Apply a function to the value of a record at a certain column. The value
-- is converted with 'fromSql' before and with 'toSql' after the function; a
-- record lacking the column is returned unchanged.
adjust :: (Convertible SqlValue a, Convertible b SqlValue)
       => String -> (a -> b) -> SqlRecord -> SqlRecord
adjust key f = M.adjust (\value -> toSql (f (fromSql value))) key


-- | Like @quickQuery'@ from "Database.HDBC", but returns every row as an
-- 'SqlRecord' (column name to value) instead of a list of values.
--
-- All rows are fetched before this action returns. No database reads are
-- deferred until the result list is consumed, possibly after a commit or
-- disconnect.
quickQueryMap :: Connection -> String -> [SqlValue] -> IO [SqlRecord]
quickQueryMap conn statement values = do
   stm <- SQL.prepare conn statement
   void $ SQL.execute stm values
   -- The primed variant reads all rows eagerly; the unprimed one is lazy.
   SQL.fetchAllRowsMap' stm


-- | Obtain the row ID produced by the most recent successful insertion.
--
-- Note that SQLite tracks @last_insert_rowid()@ per connection, not per
-- table. The @table@ argument only selects the table the query runs against,
-- so the result is 'Nothing' when that table has no rows. The table name is
-- spliced verbatim into the SQL; pass trusted names only.
latestInsertRowID :: Connection -> String -> IO (Maybe Int)
latestInsertRowID conn table = do
   -- Strict fetch: the statement is fully stepped rather than left open.
   rows <- SQL.quickQuery' conn stm []
   return $ case rows of
      ((rowID : _) : _) -> Just (fromSql rowID)
      _                 -> Nothing

   where
   stm = "SELECT last_insert_rowid() FROM "++ table ++" LIMIT 1"


-- | Test whether a table with the given name is listed by the database.
tableExists :: Connection -> String -> IO Bool
tableExists conn name = do
   tables <- SQL.getTables conn
   return (name `elem` tables)


-- | Create the given table if it does not already exist, announcing the
-- creation on standard output. Column names are quoted with backticks; the
-- table name is used verbatim.
createTable :: Connection -> SqlTable -> IO ()
createTable conn table = do
   exists <- tableExists conn name
   if exists
      then return ()
      else do
         putStrLn $ "Creating `" ++ name ++ "` table…"
         void $ SQL.run conn statement []

   where
   name = tableName table
   columnDef (col, sqlType) = "`" ++ col ++ "` " ++ sqlType
   statement =
      "CREATE TABLE " ++ name ++ wrap " (" ", " ")" (map columnDef (columns table))


-- | Drop the given table if it exists, announcing the removal on standard
-- output. Does nothing when the table is absent.
dropTable :: Connection -> SqlTable -> IO ()
dropTable conn table = do
   exists <- tableExists conn name
   when exists $ do
      putStrLn ("Dropping `" ++ name ++ "` table…")
      void (SQL.run conn ("DROP TABLE " ++ name) [])

   where
   name = tableName table


-- | Insert the given record into a table as a new row.
--
-- The record must contain at least all the columns of the table; any other
-- entries are ignored. The columns are named explicitly in the statement, so
-- the insertion does not depend on the physical column order of the table in
-- the database. Values are passed as bound parameters.
insertRecord :: Connection -> SqlTable -> SqlRecord -> IO ()
insertRecord conn table record =
   void $ SQL.run conn statement (map valueOf names)

   where
   names = map fst (columns table)

   statement =
      "INSERT INTO " ++ tableName table ++
      wrap " (" "," ")" (map (\name -> "`" ++ name ++ "`") names) ++
      " VALUES " ++ wrap "(" "," ")" (replicate (length names) "?")

   -- Report which column is missing instead of Map's generic key error.
   valueOf name = case M.lookup name record of
      Just value -> value
      Nothing    -> error $ "insertRecord: record lacks column `" ++ name
                         ++ "` of table `" ++ tableName table ++ "`"


-- | Insert each of the given records into a table, in list order, after
-- announcing on standard output how many there are. Every record must contain
-- at least all the columns of the table, as for 'insertRecord'.
insertRecords :: Connection -> SqlTable -> [SqlRecord] -> IO ()
insertRecords conn table records = do
   putStrLn message
   mapM_ insertOne records

   where
   insertOne = insertRecord conn table
   message = "Inserting " ++ show (length records) ++ " records into `"
          ++ tableName table ++ "`…"


-------------------------------------------------------------------------------
-- * Database-specific

-- ** SQL tables

-- | A student model is stored and manipulated as 'Evidence'.
type StudentModel = Evidence
-- | Identifier of a pilot (not used by the queries in this module).
type PilotID = String
-- | Identifier of a student, as stored in the @studentid@ column.
type StudentID = String
-- | Identifier of a task, as stored in the @taskid@ column.
type TaskID = String
-- | Identifier of an input element, taken from its @id@ attribute.
type InputID = String
-- | Identifier of a node, as read from the @nodeid@ column of @assessment@.
type NodeID = String

-- | Layout of the @requests@ table. Every column is declared @TEXT@ except
-- @time@ and @responsetime@, which are declared @TIME@.
requestsTable :: SqlTable
requestsTable = SqlTable "requests" (map withType names)

   where
   withType name
      | name `elem` ["time", "responsetime"] = (name, "TIME")
      | otherwise                            = (name, "TEXT")

   names =
      [ "service", "exerciseid", "source", "script", "requestinfo"
      , "dataformat", "encoding", "userid", "sessionid", "taskid"
      , "time", "responsetime", "ipaddress", "binary", "version"
      , "errormsg", "serviceinfo", "ruleid", "input", "output"
      ]

-- | Layout of the @students@ table. @requestnr@ is an @INTEGER@ column that
-- this module's join queries match against @requests.ROWID@; every other
-- column is declared @TEXT@.
studentsTable :: SqlTable
studentsTable = SqlTable "students" $
      map text ["original", "studentid", "taskid", "inputs"]
   ++ [("requestnr", "INTEGER")]
   ++ map text ["evidence", "partialmodel", "studentmodel"]

   where
   text name = (name, "TEXT")


-- | Make a default entry containing every column of the given tables, each set
-- to 'SQL.SqlNull'. A column name shared by several tables appears once.
defaultEntry :: [SqlTable] -> SqlRecord
defaultEntry tables =
   M.fromList [ (name, SQL.SqlNull) | table <- tables, (name, _) <- columns table ]


-- ** Querying

-- | Obtain all records from the @requests@ table. When a @students@ table is
-- present, each request is inner-joined with its @students@ rows (matching
-- @students.requestnr@ to @requests.ROWID@), so requests without a student
-- row are then left out. Calls 'error' when @requests@ is missing.
allRecords :: Connection -> IO [SqlRecord]
allRecords conn = SQL.getTables conn >>= select

   where
   select tables
      | "requests" `notElem` tables = error "The `requests` table is not present."
      | "students" `elem` tables    = quickQueryMap conn joinedQuery []
      | otherwise                   = quickQueryMap conn requestsQuery []

   joinedQuery = "SELECT students.ROWID,* FROM requests \
                 \INNER JOIN students \
                 \ON students.requestnr = requests.ROWID"
   requestsQuery = "SELECT requests.ROWID,* FROM requests"


-- | Obtain all joined @requests@/@students@ entries associated with a
-- particular task, ordered by @students.ROWID@.
taskRecords :: Connection -> String -> IO [SqlRecord]
taskRecords conn tID = joinedRecordsWhere conn "students.taskid" tID


-- | Obtain all joined @requests@/@students@ entries associated with a
-- particular student, ordered by @students.ROWID@.
studentRecords :: Connection -> String -> IO [SqlRecord]
studentRecords conn sID = joinedRecordsWhere conn "students.studentid" sID


-- | Query shared by 'taskRecords' and 'studentRecords'. It returns all joined
-- @requests@/@students@ rows whose given column equals the value, ordered by
-- @students.ROWID@. The column name is spliced into the SQL and must be a
-- trusted, qualified literal; the value is passed as a bound parameter.
joinedRecordsWhere :: Connection -> String -> String -> IO [SqlRecord]
joinedRecordsWhere conn column value = quickQueryMap conn stm [toSql value]

   where
   stm = "SELECT students.ROWID,* FROM requests \
      \INNER JOIN students \
      \ON students.requestnr = requests.ROWID \
      \WHERE " ++ column ++ " = ? ORDER BY students.ROWID"


-- | Obtain the joined @requests@/@students@ record whose @requests.ROWID@ is
-- the given number. Returns 'Nothing' when no such request has a matching
-- @students@ row; only the first matching row is returned.
recordN :: Connection -> Int -> IO (Maybe SqlRecord)
recordN conn i =
   SQL.prepare conn query >>= \prepared ->
      SQL.execute prepared [toSql i] >> SQL.fetchRowMap prepared

   where
   query = "SELECT requests.ROWID,* FROM requests INNER JOIN students ON students.requestnr=requests.ROWID WHERE requests.ROWID = ?"


-- | Obtain the final entry for every student. Rows are grouped by
-- @students.studentid@, with the @MAX(students.ROWID)@ aliased as @ROWID@.
-- The remaining columns rely on SQLite taking bare columns from the row that
-- produced the maximum.
finalRecords :: Connection -> IO [SqlRecord]
finalRecords conn = quickQueryMap conn finalQuery noParams

   where
   noParams = []
   finalQuery = "SELECT MAX(students.ROWID) as ROWID,* FROM requests \
      \ INNER JOIN students \
      \ ON students.requestnr = requests.ROWID \
      \ GROUP BY students.studentid \
      \ ORDER BY students.ROWID"


-- | Turn records into a table of evidence keyed by (student ID, task ID),
-- read from the @studentid@, @taskid@ and @evidence@ columns. Records sharing
-- a key are combined with 'mappend', the later record's evidence on the left:
-- for @r1@ then @r2@ the stored value is @mappend e2 e1@.
records2table :: [SqlRecord] -> M.Map (StudentID, TaskID) Evidence
records2table records =
   -- 'M.fromListWith' applies its function as @f new old@, exactly like the
   -- repeated 'M.insertWith' it replaces. It also walks the list with a strict
   -- left fold rather than building a lazy 'foldl' thunk chain.
   M.fromListWith mappend
      [ ((r ! "studentid", r ! "taskid"), r ! "evidence") | r <- records ]


-- | Get the combined evidence for every student/task pair, built by
-- 'records2table' from 'allRecords'.
allEvidence :: Connection -> IO (M.Map (StudentID, TaskID) Evidence)
allEvidence = fmap records2table . allRecords


-- | Get the evidence for every student: the per-task evidence from
-- 'allEvidence', merged over all tasks of a student with 'mappend'.
allEvidenceStudents :: Connection -> IO (M.Map StudentID Evidence)
allEvidenceStudents = fmap (M.mapKeysWith mappend fst) . allEvidence


-- | Get, for every student, the partial model of each task they worked on.
-- Per (student, task) group the query selects @MAX(ROWID)@, so the bare
-- @partialmodel@ column comes from that student/task's latest row (SQLite
-- bare-column behaviour).
allPartialModels :: Connection -> IO (M.Map StudentID [(TaskID, Evidence)])
allPartialModels conn =
   M.fromListWith (++) . map entry <$> quickQueryMap conn stm []

   where
   entry r = (r ! "studentid", [(r ! "taskid", r ! "partialmodel")])
   stm = "SELECT MAX(ROWID),studentid,taskid,partialmodel FROM students GROUP BY studentid,taskid"


-- | Get all human assessments, keyed by (student ID, task ID). Each value is a
-- list of (node ID, expected, observed) triples; a NULL expected or observed
-- value becomes 'Nothing'. Returns an empty map when the database has no
-- @assessment@ table.
allHumanAssessments :: Connection -> IO (M.Map (StudentID, TaskID) [(NodeID, Maybe String, Maybe String)])
allHumanAssessments conn = do
   present <- tableExists conn "assessment"
   if present
      -- quickQuery' reads all rows now instead of lazily on demand.
      then M.fromListWith (++)
           . map (\[sID, tID, nID, exp', obs] -> ((fromSql sID, fromSql tID), [(fromSql nID, fromSql exp', fromSql obs)]))
           <$> SQL.quickQuery' conn stm []
      else return mempty

   where
   stm = "SELECT studentid,taskid,nodeid,expected,observed FROM assessment"


-- | Get the most recent student model for every student. Per @studentid@ the
-- query selects @MAX(ROWID)@, so the bare @studentmodel@ column comes from
-- that student's latest row (SQLite bare-column behaviour).
allStudentModels :: Connection -> IO [(StudentID, StudentModel)]
allStudentModels conn = do
   rows <- quickQueryMap conn stm []
   return [ (r ! "studentid", r ! "studentmodel") | r <- rows ]

   where
   stm = "SELECT MAX(ROWID),studentid,studentmodel FROM students GROUP BY studentid"


-- | Get the @requests@ row with the highest ROWID, with its ROWID included,
-- or 'Nothing' when the table is empty.
latestRecord :: Connection -> IO (Maybe SqlRecord)
latestRecord conn = do
   prepared <- SQL.prepare conn query
   _ <- SQL.execute prepared []
   SQL.fetchRowMap prepared

   where
   query = "SELECT ROWID,* FROM requests ORDER BY rowid DESC LIMIT 1"


-- | Get the most recent student model for a particular student. This is the
-- @studentmodel@ of that student's row with the highest @requestnr@, or
-- 'mempty' when the student has no rows.
latestStudentModel :: Connection -> StudentID -> IO StudentModel
latestStudentModel conn sID = do
   rows <- SQL.quickQuery' conn stm [toSql sID]
   return $ case rows of
      ((model : _) : _) -> fromSql model
      _                 -> mempty

   where
   -- Only the first row is used. LIMIT 1 plus a strict fetch lets the
   -- statement complete instead of leaving it open mid-result.
   stm = "SELECT studentmodel FROM students \
         \WHERE studentid = ? ORDER BY requestnr DESC LIMIT 1"


-- | Get the student report for a particular student, built by 'toReport' from
-- that student's 'latestStudentModel'. The @lang@ argument is passed through
-- to 'toReport' (presumably a language code; confirm there).
latestStudentReport :: Connection -> String -> StudentID -> IO StudentReport
latestStudentReport conn lang sID = do
   model <- latestStudentModel conn sID
   toReport sID lang model


-- | Get all evidence collected for a particular student. Despite the name,
-- this combines (with 'mconcat') the @evidence@ of every @students@ row of
-- that student, not just the latest one. Rows come in the order SQLite
-- returns them, since the query has no ORDER BY.
latestEvidence :: Connection -> StudentID -> IO Evidence
latestEvidence conn sID =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   mconcat . map firstColumn <$> SQL.quickQuery' conn stm [toSql sID]
   where
   stm = "SELECT evidence FROM students WHERE studentid = ?"
   firstColumn row = case row of
      (evidence : _) -> fromSql evidence
      []             -> mempty


-- | Get, for a particular student, the partial model of each task they worked
-- on. Per task the query selects @MAX(ROWID)@, so the bare @partialmodel@
-- column comes from that task's latest row (SQLite bare-column behaviour).
latestPartialModels :: Connection -> StudentID -> IO [(TaskID, Evidence)]
latestPartialModels conn sID =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   map (\[_,tID,pm] -> (fromSql tID, fromSql pm)) <$> SQL.quickQuery' conn stm [toSql sID]
   where
   stm = "SELECT MAX(ROWID),taskid,partialmodel FROM students WHERE studentid = ? GROUP BY taskid"


-- | Get the distinct tasks for which a particular student has at least one
-- row in the @students@ table. Tasks the student has not attempted are not
-- included.
tasksDone :: Connection -> StudentID -> IO [TaskID]
tasksDone conn sID =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   map (fromSql . head) <$> SQL.quickQuery' conn stm [toSql sID]
   where
   -- A single selected column, so every row is non-empty and 'head' is safe.
   stm = "SELECT taskid FROM students WHERE studentid = ? GROUP BY taskid"


-- | Obtain, for each student ID, the number of rows it has in the @students@
-- table (one per processed request).
countStudentEntries :: Connection -> IO [(String, Int)]
countStudentEntries conn =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   map (\[sID, n] -> (fromSql sID, fromSql n)) <$> SQL.quickQuery' conn stm []

   where
   stm = "SELECT studentid,count(*) FROM students GROUP BY studentid"


-- | Obtain, for each task ID, the number of rows it has in the @students@
-- table (one per processed request).
countTaskEntries :: Connection -> IO [(String, Int)]
countTaskEntries conn =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   map (\[tID, n] -> (fromSql tID, fromSql n)) <$> SQL.quickQuery' conn stm []

   where
   stm = "SELECT taskid,count(*) FROM students GROUP BY taskid"


-- | Obtain the number of @students@ rows for each (student ID, task ID) pair.
-- Pairs without rows do not appear, so every count is at least 1. A count
-- above 1 means the student has several entries for the same task.
countEntries :: Connection -> IO [((String, String), Int)]
countEntries conn =
   -- quickQuery' fetches all rows now; the lazy variant would defer reads.
   map (\[sID, tID, n] -> ((fromSql sID, fromSql tID), fromSql n))
      <$> SQL.quickQuery' conn stm []

   where
   stm = "SELECT studentid,taskid,count(*) FROM students \
         \GROUP BY studentid, taskid"


-- | Obtain the total number of students: the number of @studentid@ groups in
-- the @students@ table. Falls back to 0 if the query yields no row.
countStudents :: Connection -> IO Int
countStudents conn =
   SQL.prepare conn stm >>= \prepared -> do
      _ <- SQL.execute prepared []
      row <- SQL.fetchRow prepared
      return (maybe 0 (fromSql . head) row)

   where
   stm = "SELECT count(*) FROM \
         \   (SELECT count(*) FROM students GROUP BY studentid)"