From 7f9bd4d7d04f4c9c99c0b52077c145b8255ab029 Mon Sep 17 00:00:00 2001
From: Joey Adams <joeyadams3.14159@gmail.com>
Date: Mon, 12 Nov 2012 21:48:08 -0500
Subject: [PATCH] GHC.Windows: more error support (guards, system error
 strings)

---
 GHC/Windows.hs     | 149 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 cbits/Win32Utils.c |  69 +++++++++++++++----------
 include/HsBase.h   |   1 +
 3 files changed, 180 insertions(+), 39 deletions(-)

diff --git a/GHC/Windows.hs b/GHC/Windows.hs
index fa25f63..fbcf97e 100644
--- a/GHC/Windows.hs
+++ b/GHC/Windows.hs
@@ -1,5 +1,7 @@
 {-# LANGUAGE Trustworthy #-}
-{-# LANGUAGE NoImplicitPrelude, ForeignFunctionInterface #-}
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE NoImplicitPrelude #-}
 -----------------------------------------------------------------------------
 -- |
 -- Module      :  GHC.Windows
@@ -19,30 +21,153 @@
 -----------------------------------------------------------------------------
 
 module GHC.Windows (
-        HANDLE, DWORD, LPTSTR, iNFINITE,
-        throwGetLastError, c_maperrno
-    ) where
+        -- * Types
+        BOOL,
+        DWORD,
+        ErrCode,
+        HANDLE,
+        LPWSTR,
+        LPTSTR,
 
-import GHC.Base
-import GHC.Ptr
+        -- * Constants
+        iNFINITE,
+        iNVALID_HANDLE_VALUE,
 
-import Data.Word
+        -- * System errors
+        throwGetLastError,
+        c_maperrno,
+        c_maperrno_func,
+        getErrorMessage,
+        getLastError,
+        errCodeToIOError,
+        failWith,
 
-import Foreign.C.Error (throwErrno)
+        -- ** Guards for system calls that might fail
+        failIf,
+        failIf_,
+        failIfNull,
+        failIfZero,
+        failIfFalse_,
+        failUnlessSuccess,
+        failUnlessSuccessOr,
+    ) where
+
+import Data.Char
+import Data.List
+import Data.Maybe
+import Data.Word
+import Foreign.C.Error
+import Foreign.C.String
 import Foreign.C.Types
+import Foreign.Ptr
+import GHC.Base
+import GHC.IO
+import GHC.Num
+import System.IO.Error
 
+import qualified Numeric
 
-type HANDLE       = Ptr ()
-type DWORD        = Word32
+#ifdef mingw32_HOST_OS
+# if defined(i386_HOST_ARCH)
+#  define WINDOWS_CCONV stdcall
+# elif defined(x86_64_HOST_ARCH)
+#  define WINDOWS_CCONV ccall
+# else
+#  error Unknown mingw32 arch
+# endif
+#endif
 
-type LPTSTR = Ptr CWchar
+type BOOL       = Bool
+type DWORD      = Word32
+type ErrCode    = DWORD
+type HANDLE     = Ptr ()
+type LPWSTR     = Ptr CWchar
+type LPTSTR     = LPWSTR
 
 iNFINITE :: DWORD
 iNFINITE = 0xFFFFFFFF -- urgh
 
+iNVALID_HANDLE_VALUE :: HANDLE
+iNVALID_HANDLE_VALUE = wordPtrToPtr (-1)
+
 throwGetLastError :: String -> IO a
-throwGetLastError where_from = c_maperrno >> throwErrno where_from
+throwGetLastError where_from =
+    getLastError >>= failWith where_from
 
 foreign import ccall unsafe "maperrno"             -- in Win32Utils.c
    c_maperrno :: IO ()
 
+foreign import ccall unsafe "maperrno_func"        -- in Win32Utils.c
+   c_maperrno_func :: ErrCode -> Errno
+
+foreign import ccall unsafe "base_getErrorMessage" -- in Win32Utils.c
+    c_getErrorMessage :: DWORD -> IO LPWSTR
+
+foreign import WINDOWS_CCONV unsafe "windows.h LocalFree"
+    localFree :: Ptr a -> IO (Ptr a)
+
+foreign import WINDOWS_CCONV unsafe "windows.h GetLastError"
+    getLastError :: IO ErrCode
+
+
+failIf :: (a -> Bool) -> String -> IO a -> IO a
+failIf p wh act = do
+    v <- act
+    if p v then throwGetLastError wh else return v
+
+failIf_ :: (a -> Bool) -> String -> IO a -> IO ()
+failIf_ p wh act = do
+    v <- act
+    if p v then throwGetLastError wh else return ()
+
+failIfNull :: String -> IO (Ptr a) -> IO (Ptr a)
+failIfNull = failIf (== nullPtr)
+
+failIfZero :: (Eq a, Num a) => String -> IO a -> IO a
+failIfZero = failIf (== 0)
+
+failIfFalse_ :: String -> IO Bool -> IO ()
+failIfFalse_ = failIf_ not
+
+failUnlessSuccess :: String -> IO ErrCode -> IO ()
+failUnlessSuccess fn_name act = do
+    r <- act
+    if r == 0 then return () else failWith fn_name r
+
+failUnlessSuccessOr :: ErrCode -> String -> IO ErrCode -> IO Bool
+failUnlessSuccessOr val fn_name act = do
+    r <- act
+    if r == 0 then return False
+        else if r == val then return True
+        else failWith fn_name r
+
+-- | Convert a Windows error code to an exception, then throw it.
+failWith :: String -> ErrCode -> IO a
+failWith fn_name err_code =
+    errCodeToIOError fn_name err_code >>= throwIO
+
+-- | Convert a Windows error code to an exception.
+errCodeToIOError :: String -> ErrCode -> IO IOError
+errCodeToIOError fn_name err_code = do
+    msg <- getErrorMessage err_code
+
+    -- turn GetLastError() into errno, which errnoToIOError knows
+    -- how to convert to an IOException we can throw.
+    -- XXX we should really do this directly.
+    let errno = c_maperrno_func err_code
+
+    let msg' = reverse $ dropWhile isSpace $ reverse msg -- drop trailing \n
+        ioerror = errnoToIOError fn_name errno Nothing Nothing
+                    `ioeSetErrorString` msg'
+    return ioerror
+
+getErrorMessage :: ErrCode -> IO String
+getErrorMessage err_code =
+    mask_ $ do
+        c_msg <- c_getErrorMessage err_code
+        if c_msg == nullPtr
+          then return $ "Error 0x" ++ Numeric.showHex err_code ""
+          else do msg <- peekCWString c_msg
+                  -- We ignore failure of freeing c_msg, given we're already failing
+                  _ <- localFree c_msg
+                  return msg
diff --git a/cbits/Win32Utils.c b/cbits/Win32Utils.c
index ecd54f3..7038cbf 100644
--- a/cbits/Win32Utils.c
+++ b/cbits/Win32Utils.c
@@ -80,34 +80,49 @@ static struct errentry errtable[] = {
 #define MIN_EACCES_RANGE ERROR_WRITE_PROTECT
 #define MAX_EACCES_RANGE ERROR_SHARING_BUFFER_EXCEEDED
 
-void maperrno (void)
+void maperrno(void)
 {
-	int i;
-	DWORD dwErrorCode;
-
-	dwErrorCode = GetLastError();
-
-	/* check the table for the OS error code */
-	for (i = 0; i < ERRTABLESIZE; ++i)
-	{
-		if (dwErrorCode == errtable[i].oscode)
-		{
-			errno = errtable[i].errnocode;
-			return;
-		}
-	}
-
-	/* The error code wasn't in the table.  We check for a range of */
-	/* EACCES errors or exec failure errors (ENOEXEC).  Otherwise   */
-	/* EINVAL is returned.                                          */
-
-	if (dwErrorCode >= MIN_EACCES_RANGE && dwErrorCode <= MAX_EACCES_RANGE)
-		errno = EACCES;
-	else
-		if (dwErrorCode >= MIN_EXEC_ERROR && dwErrorCode <= MAX_EXEC_ERROR)
-			errno = ENOEXEC;
-		else
-			errno = EINVAL;
+    errno = maperrno_func(GetLastError());
+}
+
+int maperrno_func(DWORD dwErrorCode)
+{
+    int i;
+
+    /* check the table for the OS error code */
+    for (i = 0; i < ERRTABLESIZE; ++i)
+        if (dwErrorCode == errtable[i].oscode)
+            return errtable[i].errnocode;
+
+    /* The error code wasn't in the table.  We check for a range of */
+    /* EACCES errors or exec failure errors (ENOEXEC).  Otherwise   */
+    /* EINVAL is returned.                                          */
+
+    if (dwErrorCode >= MIN_EACCES_RANGE && dwErrorCode <= MAX_EACCES_RANGE)
+        return EACCES;
+    else if (dwErrorCode >= MIN_EXEC_ERROR && dwErrorCode <= MAX_EXEC_ERROR)
+        return ENOEXEC;
+    else
+        return EINVAL;
+}
+
+LPWSTR base_getErrorMessage(DWORD err)
+{
+    LPWSTR what;
+    DWORD res;
+
+    res = FormatMessageW(
+              (FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER),
+              NULL,
+              err,
+              MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), /* Default language */
+              (LPWSTR) &what,
+              0,
+              NULL
+          );
+    if (res == 0)
+        return NULL;
+    return what;
 }
 
 int get_unique_file_info(int fd, HsWord64 *dev, HsWord64 *ino)
diff --git a/include/HsBase.h b/include/HsBase.h
index 74ab816..b1a62fd 100644
--- a/include/HsBase.h
+++ b/include/HsBase.h
@@ -141,6 +141,7 @@
 #if defined(__MINGW32__)
 /* in Win32Utils.c */
 extern void maperrno (void);
+extern int maperrno_func(DWORD dwErrorCode);
 extern HsWord64 getMonotonicUSec(void);
 #endif
 
-- 
1.8.0.msysgit.0

