{-# LANGUAGE RankNTypes #-}

module System.Zfs.Zpool where

import Control.Monad
import Control.Monad.IO.Class
import Data.IORef (modifyIORef', newIORef, readIORef, writeIORef)
import Foreign.C.String
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable

import System.Zfs.Errors
import qualified System.Zfs.Lowlevel as L
import System.Zfs.Types

-- | Open the pool called @str@.
--
-- On success the raw handle is wrapped in a 'ForeignPtr' whose finalizer is
-- @zpool_close_@; when @zpool_open@ yields a null pointer the result of
-- 'yieldZError' is returned instead.
getZpool :: String -> Zfs z (Zpool z)
getZpool str = Zfs $ \(ZfsContext z) -> liftIO $ do
  handle <- withCString str (L.zpool_open z)
  if handle /= nullPtr
    then Right . Zpool <$> newForeignPtr L.zpool_close_ handle
    else yieldZError z

-- | Gets the name of some Zpool.
--
-- The C string handed back by @zpool_get_name@ is copied into a Haskell
-- 'String' while the pool handle is still held by 'withForeignPtr'. Reading it
-- after that scope has ended would allow the handle's finalizer to run first
-- and leave the pointer dangling (assuming the string is owned by the handle,
-- which is why it is never freed here).
getZpoolName :: Zpool z -> Zfs z String
getZpoolName (Zpool fptr) = Zfs $ \_ -> liftIO $
  Right <$> withForeignPtr fptr (L.zpool_get_name >=> peekCString)

-- | Run some function for all available Zpools and return the results as a
-- list, in the order in which @zpool_iter@ visits the pools.
--
-- Each handle passed to the callback is wrapped in a 'ForeignPtr' with
-- @zpool_close_@ as finalizer before @f@ is run on it.
--
-- If @f@ fails for a pool, the callback returns a non-zero code (presumably
-- making libzfs stop iterating -- TODO confirm) and the error of that run is
-- returned as the overall result instead of a partial list.
--
-- Results are accumulated in an 'IORef' rather than a pre-sized raw buffer, so
-- no separate counting pass is needed and the number of pools may change
-- between calls without corrupting memory.
forZpools :: MonadIO m => (Zpool z -> ZfsT z IO a) -> ZfsT z m [a]
forZpools f = Zfs $ \(ZfsContext z) -> liftIO $ do
  resultsRef <- newIORef []
  failureRef <- newIORef Nothing
  let visit zpool' _ = do
        zpool <- liftM Zpool $ newForeignPtr L.zpool_close_ zpool'
        res <- runZfs' (f zpool) (ZfsContext z)
        case res of
          Left e -> do
            -- Remember the failure so it can be reported to the caller.
            writeIORef failureRef (Just e)
            return 1 -- TODO: Return the right error code
          Right a -> do
            -- Prepend now, reverse once at the end.
            modifyIORef' resultsRef (a :)
            return 0
  fun' <- L.wrap_zpool_iter visit
  _ <- L.zpool_iter z fun' nullPtr
  -- The callback stub is assumed to come from a foreign "wrapper" import and
  -- therefore has to be released explicitly.
  freeHaskellFunPtr fun'
  failure <- readIORef failureRef
  case failure of
    Just e -> return $ Left e
    Nothing -> Right . reverse <$> readIORef resultsRef

-- | Run some function for all available Zpools without result. Less expensive
-- than 'forZpools'.
--
-- Each handle passed to the callback is wrapped in a 'ForeignPtr' with
-- @zpool_close_@ as finalizer before @f@ is run on it. If @f@ fails for a
-- pool, the callback returns a non-zero code (presumably making libzfs stop
-- iterating -- TODO confirm) and that error becomes the overall result.
forZpools_ :: MonadIO m => (Zpool z -> ZfsT z IO ()) -> ZfsT z m ()
forZpools_ f = Zfs $ \(ZfsContext z) -> liftIO $ do
  failureRef <- newIORef Nothing
  let visit zpool' _ = do
        zpool <- liftM Zpool $ newForeignPtr L.zpool_close_ zpool'
        res <- runZfs' (f zpool) (ZfsContext z)
        case res of
          Left e -> do
            -- Remember the failure so it can be reported to the caller.
            writeIORef failureRef (Just e)
            return 1 -- TODO: Return the right error code
          Right _ -> return 0
  fun' <- L.wrap_zpool_iter visit
  _ <- L.zpool_iter z fun' nullPtr
  -- The callback stub is assumed to come from a foreign "wrapper" import and
  -- therefore has to be released explicitly.
  freeHaskellFunPtr fun'
  maybe (Right ()) Left <$> readIORef failureRef

-- | Get the number of available Zpools. Less expensive than
-- `length <$> getZpools'.
--
-- A counter living in 'alloca'-ed memory is passed to @zpool_iter@ as user
-- data; the callback increments it once per pool and closes the handle it was
-- given right away.
getZpoolCount :: Zfs z Int
getZpoolCount = Zfs $ \(ZfsContext z) ->
  liftIO $ alloca $ \ptr -> do
    poke ptr (0 :: Int)
    fun' <- L.wrap_zpool_iter fun
    _ <- L.zpool_iter z fun' ptr
    -- The callback stub is assumed to come from a foreign "wrapper" import
    -- and therefore has to be released explicitly.
    freeHaskellFunPtr fun'
    Right <$> peek ptr
  where fun zpool ptr = do
          i <- peek ptr
          poke ptr (i+1 :: Int)
          L.zpool_close zpool
          return 0

-- | Names of the pool features that 'createZpool' enables on new pools.
-- Each entry is turned into a @feature\@<name>@ property there.
defaultFeatures :: [String]
defaultFeatures =
  [ "async_destroy"
  , "empty_bpobj"
  , "lz4_compress"
  , "spacemap_histogram"
  , "enabled_txg"
  , "hole_birth"
  , "extensible_dataset"
  , "embedded_data"
  , "bookmarks"
  , "filesystem_limits"
  , "large_blocks"
  ]

-- | Build the nvlist describing a single 'Vdev'.
--
-- A 'VdFile' becomes a list with @path@, @type = "file"@ and @is_log = 0@;
-- a 'VdMirror' becomes a list with @type = "mirror"@, @is_log = 0@ and a
-- @children@ array built recursively from its members.
--
-- All key and value strings, as well as the temporary child array, are
-- scoped with 'withCString' / 'withArray' so they are released once they
-- have been added; libnvpair copies names and values into the list.
--
-- NOTE(review): the child nvlists themselves are not released after being
-- added to the parent; if the add operation copies them they should be freed
-- with the appropriate libnvpair call -- confirm against "System.Zfs.Lowlevel".
mkVdev :: Vdev -> IO L.NvlistPtr
mkVdev vdev =
  withCString "type" $ \str_type ->
  withCString "file" $ \str_file ->
  withCString "mirror" $ \str_mirror ->
  withCString "children" $ \str_children ->
  withCString "path" $ \str_path ->
  withCString "is_log" $ \str_is_log ->
    let -- Allocate an empty nvlist and return the pointer written by libnvpair.
        newNvlist = alloca $ \nvptr -> do
          _ <- L.nvlist_alloc nvptr 1 {- NV_UNIQUE_NAME -} 1 {- KM_SLEEP -}
          peek nvptr
        build (VdFile fpath) = do
          nvchild <- newNvlist
          _ <- withCString fpath $ \str_the_path ->
            L.nvlist_add_string nvchild str_path str_the_path
          _ <- L.nvlist_add_string nvchild str_type str_file
          _ <- L.nvlist_add_uint64 nvchild str_is_log 0
          return nvchild
        build (VdMirror vs) = do
          nvmirror <- newNvlist
          _ <- L.nvlist_add_string nvmirror str_type str_mirror
          _ <- L.nvlist_add_uint64 nvmirror str_is_log 0
          children <- mapM build vs
          _ <- withArray children $ \child_arr ->
            L.nvlist_add_nvlist_array nvmirror str_children child_arr (fromIntegral $ length vs)
          return nvmirror
    in build vdev

-- | Create a zpool consisting of the given devices.
--
-- Builds a root nvlist (@type = "root"@, @children@ = one nvlist per 'Vdev'
-- as produced by 'mkVdev') and a property nvlist enabling every entry of
-- 'defaultFeatures' (as @feature\@<name> = "enabled"@), then calls
-- @zpool_create@. A non-zero return code is reported through 'yieldZError'.
--
-- Temporary C strings and the child array are scoped with 'withCString' /
-- 'withArray' so they are released after use; libnvpair copies names and
-- values into the list.
--
-- NOTE(review): the root, property and child nvlists are not released after
-- the call; confirm whether "System.Zfs.Lowlevel" exposes a free function.
createZpool :: String -> [Vdev] -> Zfs z ()
createZpool name vdevs = Zfs $ \(ZfsContext z) -> liftIO $ do
  nvroot <- newNvlist
  _ <- withCString "type" $ \str_type ->
    withCString "root" $ \str_root ->
      L.nvlist_add_string nvroot str_type str_root
  children <- mapM mkVdev vdevs
  props <- newNvlist
  withCString "enabled" $ \str_enabled ->
    forM_ defaultFeatures $ \feat ->
      withCString ("feature@"++feat) $ \str_feat ->
        L.nvlist_add_string props str_feat str_enabled
  _ <- withCString "children" $ \str_children ->
    withArray children $ \child_arr ->
      L.nvlist_add_nvlist_array nvroot str_children child_arr (fromIntegral $ length vdevs)
  res <- withCString name $ \cname -> L.zpool_create z cname nvroot props nullPtr
  if res == 0
    then return $ Right ()
    else yieldZError z
  where -- Allocate an empty nvlist and return the pointer written by libnvpair.
        newNvlist = alloca $ \nvptr -> do
          _ <- L.nvlist_alloc nvptr 1 {- NV_UNIQUE_NAME -} 1 {- KM_SLEEP -}
          peek nvptr

-- | Collect a handle for every available Zpool by running 'forZpools' with an
-- action that simply hands each pool back.
getZpools :: Zfs z [Zpool z]
getZpools = forZpools (\zpool -> return zpool)

-- | Get Zpool config.
--
-- Reads @name@, @hostname@, @version@, @state@ and @pool_guid@ from the
-- nvlist returned by @zpool_get_config@.
--
-- All lookups run while the pool handle is held by 'withForeignPtr', since
-- the config list is obtained from the handle and must not outlive it.
-- Lookup out-parameters are initialised before the call, so a key that is
-- absent from the config yields @""@ (strings) or @0@ (integers) instead of
-- whatever happened to be in the uninitialised memory.
getZpoolConfig :: Zpool z -> Zfs z ZpoolConfig
getZpoolConfig (Zpool fptr) = Zfs $ \_ -> liftIO $
  withForeignPtr fptr $ \ptr -> do
    nvl <- L.zpool_get_config ptr nullPtr
    name <- getstr nvl "name"
    hostname <- getstr nvl "hostname"
    version <- getint nvl "version"
    state <- getint nvl "state"
    guid <- getint nvl "pool_guid"
    return $ Right $ ZpoolConfig version name state guid hostname
  where getstr nvl key =
          withCString key $ \cstr ->
          alloca $ \out -> do
            -- Null marks "nothing was written by the lookup".
            poke out nullPtr
            _ <- L.nvlist_lookup_string nvl cstr out
            val <- peek out
            if val == nullPtr
              then return ""
              else peekCString val
        getint nvl key =
          withCString key $ \cstr ->
          alloca $ \out -> do
            poke out 0
            _ <- L.nvlist_lookup_uint64 nvl cstr out
            fromIntegral <$> peek out

-- | Destroy a zpool.
--
-- First calls @zpool_disable_datasets@ with the given @force@ flag, then
-- @zpool_destroy@. A non-zero return code from either step is reported
-- through 'yieldZError'; the second step only runs if the first succeeded.
destroyZpool :: Zpool z -> Bool -> Zfs z ()
destroyZpool (Zpool fptr) force = Zfs $ \(ZfsContext z) -> liftIO $ do
  disableRc <- withForeignPtr fptr (\hdl -> L.zpool_disable_datasets hdl force)
  if disableRc == 0
    then do
      destroyRc <- withForeignPtr fptr (\hdl -> L.zpool_destroy hdl nullPtr)
      if destroyRc == 0
        then return $ Right ()
        else yieldZError z
    else yieldZError z

-- | Get Vdevs. TODO: Whole disks, raidz
--
-- Looks up @vdev_tree@ in the pool config and walks its @children@ arrays.
-- Every child is named via @zpool_vdev_name@; a child named @"mirror"@ is
-- descended into and becomes a 'VdMirror', anything else becomes a 'VdFile'
-- carrying that name.
--
-- The whole traversal runs while the pool handle is held by
-- 'withForeignPtr', since the config list is obtained from the handle. The
-- child-count out-parameter is zeroed before each lookup so that a node
-- without a @children@ entry is treated as having no children.
--
-- NOTE(review): @zpool_vdev_name@ is called with its last argument set to
-- 'True'; verify that interior vdevs are really reported as exactly
-- @"mirror"@ in that mode (and not e.g. with an id suffix).
getVdevs :: Zpool z -> Zfs z [Vdev]
getVdevs (Zpool fptr) = Zfs $ \(ZfsContext z) -> liftIO $
  withForeignPtr fptr $ \ptr -> do
    nvl <- L.zpool_get_config ptr nullPtr
    nvroot <- alloca $ \nvptr ->
      withCString "vdev_tree" $ \str_vdev_tree -> do
        _ <- L.nvlist_lookup_nvlist nvl str_vdev_tree nvptr
        peek nvptr
    let -- Read the "children" nvlist array of a node.
        lookupChildren rt =
          alloca $ \nvptr ->
          alloca $ \szptr ->
          withCString "children" $ \str_children -> do
            poke szptr 0
            _ <- L.nvlist_lookup_nvlist_array rt str_children nvptr szptr
            sz <- peek szptr
            arr <- peek nvptr
            -- peekArray does not touch arr when the count is zero.
            peekArray (fromIntegral sz) arr
        fetchchildren rt = do
          children <- lookupChildren rt
          forM children $ \child -> do
            cstr <- L.zpool_vdev_name z ptr child True
            str <- peekCString cstr
            -- The name string is released here after being copied.
            free cstr
            case str of
              "mirror" -> VdMirror <$> fetchchildren child
              _ -> return $ VdFile str
    vdevs <- fetchchildren nvroot
    return $ Right vdevs

-- | Apply @f@ to every non-mirror 'Vdev' of a pool and collect the results.
--
-- The vdev list comes from 'getVdevs'. A 'VdMirror' is not passed to @f@
-- itself; its members are visited recursively and their results are
-- flattened into the output list.
forVdevs :: MonadIO m => Zpool z -> (Vdev -> ZfsT z m a) -> ZfsT z m [a]
forVdevs zpool f = do
  topLevel <- getVdevs zpool
  let visit (VdMirror members) = fmap concat (mapM visit members)
      visit leaf = fmap (: []) (f leaf)
  fmap concat (mapM visit topLevel)