{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE LambdaCase #-}
module BaseXClient.Session where

import Control.Applicative
import Control.Exception (ErrorCall (..), finally, onException, throwIO)
import Data.List
import System.IO

import Data.ByteString.Lazy.UTF8 (fromString)
import qualified Data.Digest.Pure.MD5 as MD5
import Network

import BaseXClient.Query (Query(..))
import BaseXClient.Utils

-- | The reply to a command run through 'execute'.
--
-- Both fields are read from the server, in declaration order, right after
-- the command is sent.
data Result = Result
  { content :: String
    -- ^ The first string the server returns (the command's output).
  , info :: String
    -- ^ The second string the server returns. On failure 'execute' reports
    -- this text as the error, so it presumably carries server messages.
  }
  deriving (Show)

-- | Open an authenticated session with a BaseX server.
--
-- After connecting, the server sends a challenge string:
--
-- * If it contains a colon, it is split at the first colon into
--   @realm:nonce@. The hashed secret is then @user:realm:pass@.
-- * Otherwise the whole challenge is the nonce and the secret is @pass@.
--
-- The client replies with the user name and @md5(md5(secret) ++ nonce)@,
-- both hex-encoded.
--
-- Throws 'ErrorCall' @"Access denied."@ when 'ok' reports a rejection.
-- If anything fails after the socket is opened, the handle is closed before
-- the exception propagates, so a failed login never leaks the connection.
connect :: String -> PortNumber -> String -> String -> IO Handle
connect host port user pass = do
  session <- connectTo host (PortNumber port)
  login session `onException` hClose session
  return session
  where
    -- Configure buffering and run the challenge/response handshake.
    login session = do
      hSetBuffering session (BlockBuffering (Just 4096))
      challenge <- readString session
      let (secret, nonce) = case break (== ':') challenge of
            (realm, ':' : rest) -> (intercalate ":" [user, realm, pass], rest)
            -- No realm in the challenge: the entire string is the nonce.
            _ -> (pass, challenge)
      writeStrings session [user, md5 (md5 secret ++ nonce)]
      accepted <- ok session
      -- Throw inside IO so the failure surfaces here, not when the caller
      -- later forces the returned handle.
      if accepted
        then return ()
        else throwIO (ErrorCall "Access denied.")
    -- Hex digest of the UTF-8 encoding of a string.
    md5 = show . MD5.md5 . fromString

-- | Run a database command on an open session.
--
-- Sends @cmd@, then reads two strings from the server: the command output
-- (stored as 'content') and an info string (stored as 'info'). Finally it
-- asks 'ok' whether the command succeeded. On failure, an 'ErrorCall'
-- carrying the info string is thrown.
execute :: Handle -> String -> IO Result
execute session cmd = do
  writeString session cmd
  content <- readString session
  info <- readString session
  succeeded <- ok session
  -- Throw in IO rather than embedding 'error' in the returned value. A lazy
  -- 'error' only fires once the Result is forced, so a failed command whose
  -- result is discarded would otherwise go unnoticed.
  if succeeded
    then return Result{content, info}
    else throwIO (ErrorCall info)

-- | Register the query expression @q@ on this session.
--
-- The expression is sent via 'exec' with command code 0. The string 'exec'
-- returns is used as the query's identifier in the resulting 'Query'.
query :: Handle -> String -> IO Query
query session q = Query session <$> exec session 0 [q]

-- | Commands that send one argument plus one input string to the server.
--
-- Each is 'sendInput' with a fixed command code. The code is the only
-- difference between them.
create, add, replace, store :: Handle -> String -> String -> IO String
create session arg input = sendInput 8 session arg input
add session arg input = sendInput 9 session arg input
replace session arg input = sendInput 12 session arg input
store session arg input = sendInput 13 session arg input

-- | Send command @code@ with an argument and an input string.
--
-- @arg@ and @input@ are passed to 'exec' as a two-element list, in that
-- order. Whatever 'exec' returns is passed through unchanged.
sendInput :: Int -> Handle -> String -> String -> IO String
sendInput code session arg input = exec session code payload
  where
    payload = [arg, input]

-- | End a session: send the @exit@ command, then close the handle.
--
-- The handle is closed even if sending @exit@ throws (for example, because
-- the server already dropped the connection), so the socket is never leaked.
-- Any exception from sending @exit@ still propagates after the close.
close :: Handle -> IO ()
close session = writeString session "exit" `finally` hClose session