module System.Win32.Com.Dll
(
ComponentInfo(..)
, mkComponentInfo
, withComponentName
, withProgID
, withVerIndepProgID
, onRegister
, onFinalize
, hasTypeLib
, createIComDll
, regAddEntry
, regRemoveEntry
, RegHive(..)
, stdRegComponent
, stdUnRegComponent
, ComponentFactory
) where
import System.Win32.Com.ClassFactory
import System.Win32.Com
import System.Win32.Com.Exception
import System.Win32.Com.Server
import System.Win32.Com.Base ( getModuleFileName )
import System.Win32.Com.HDirect.HDirect ( Ptr, marshallString )
import Foreign hiding ( Ptr )
import Data.Word ( Word32 )
import System.Win32.Com.HDirect.HDirect ( marshallMaybe )
import Data.IORef ( IORef, newIORef, readIORef, writeIORef )
import Control.Exception
import Control.Monad
import Data.List ( find )
-- | Description of one COM component served by this DLL: how instances are
-- created, how it is registered, and the CLSID that identifies it.
--
-- Values are normally built with 'mkComponentInfo' and then adjusted with
-- the @with*@ modifiers, 'onRegister', 'onFinalize' and 'hasTypeLib'.
data ComponentInfo
  = ComponentInfo
  { newInstance :: ComponentFactory
    -- ^ Instance constructor. 'dllGetClassObject' applies it to the DLL
    -- path and 'componentFinalise' before passing it to 'createClassFactory'.
  , componentFinalise :: IO ()
    -- ^ Finaliser action handed to 'newInstance'. 'onFinalize' prepends
    -- actions to it; 'withFinaliser' replaces it.
  , componentName :: String
    -- ^ Human-readable name, written as the value of the component's
    -- CLSID key by 'stdRegComponent'.
  , componentProgID :: String
    -- ^ ProgID; the empty string means "none" and its registry entries
    -- are skipped.
  , componentVProgID :: String
    -- ^ Version-independent ProgID; the empty string means "none".
  , componentTLB :: Bool
    -- ^ When True, 'stdRegComponent' also (best-effort) calls
    -- 'loadTypeLibEx' on the DLL path.
  , registerComponent :: ComponentInfo -> String -> Bool -> IO ()
    -- ^ Extra registration hook, called by 'registerServer' with this
    -- record, the DLL path, and True when registering / False when
    -- unregistering.
  , componentCLSID :: CLSID
    -- ^ Class ID, matched by 'lookupCLSID' and used to build the
    -- registry key paths.
  }
-- | Constructor for component instances.
--
-- As used by 'dllGetClassObject', it is given the DLL path and the
-- component's finaliser action. The remaining IID argument is presumably
-- the interface requested from the class factory (TODO confirm against
-- 'createClassFactory').
type ComponentFactory =
  String -> IO () -> IID (IUnknown ()) -> IO (IUnknown ())
-- | Set the ProgID under which the component is registered.
withProgID :: String -> ComponentInfo -> ComponentInfo
withProgID progId ci = ci { componentProgID = progId }
-- | Install an additional (un)registration hook.
--
-- The new hook runs first, followed by whatever hook was already present.
-- Both receive the same arguments.
onRegister :: (ComponentInfo -> String -> Bool -> IO ()) -> ComponentInfo -> ComponentInfo
onRegister hook ci = ci { registerComponent = chained }
  where
    previous = registerComponent ci
    chained self p flag = do
      hook self p flag
      previous self p flag
-- | Prepend an action to the component's finaliser. The new action runs
-- before the existing finaliser.
onFinalize :: IO () -> ComponentInfo -> ComponentInfo
onFinalize extra ci = ci { componentFinalise = combined }
  where
    combined = do
      extra
      componentFinalise ci
-- | Set the version-independent ProgID of the component.
withVerIndepProgID :: String -> ComponentInfo -> ComponentInfo
withVerIndepProgID vProgId ci = ci { componentVProgID = vProgId }
-- | Replace the component's finaliser outright. Unlike 'onFinalize', this
-- does not chain with the previous finaliser.
withFinaliser :: IO () -> ComponentInfo -> ComponentInfo
withFinaliser fin ci = ci { componentFinalise = fin }
-- | Set the human-readable component name stored under its CLSID key.
withComponentName :: String -> ComponentInfo -> ComponentInfo
withComponentName name ci = ci { componentName = name }
-- | Mark the component as shipping a type library in the DLL, so that
-- 'stdRegComponent' attempts to load and register it.
hasTypeLib :: ComponentInfo -> ComponentInfo
hasTypeLib ci = ci { componentTLB = True }
-- | Build a 'ComponentInfo' with defaults:
--
--   * no finaliser;
--   * empty name, ProgID and version-independent ProgID;
--   * no type library.
--
-- Arguments, in order:
--
--   * the CLSID;
--   * a registration hook taking the DLL path and a True (register) /
--     False (unregister) flag;
--   * the instance factory.
mkComponentInfo :: CLSID
                -> (String -> Bool -> IO ())
                -> (String -> IO () -> IID (IUnknown ()) -> IO (IUnknown ()))
                -> ComponentInfo
mkComponentInfo cls reg factory =
  ComponentInfo
    { newInstance       = factory
    , componentFinalise = return ()
    , componentName     = ""
    , componentProgID   = ""
    , componentVProgID  = ""
    , componentTLB      = False
      -- the hook ignores the ComponentInfo it is handed
    , registerComponent = const reg
    , componentCLSID    = cls
    }
-- | Per-DLL state shared by the exported entry points.
data ComDllState = ComDllState
  { dllPath    :: String
    -- ^ File name of the DLL, obtained via 'getModuleFileName'.
  , components :: IORef [ComponentInfo]
    -- ^ Components this DLL can create class factories for.
  , lockCount  :: IORef Int
    -- ^ Read by 'dllCanUnloadNow', which returns s_OK only when it is 0.
    -- 'newComDllState' initialises it to 1. NOTE(review): no visible code
    -- changes it afterwards; confirm whether it is updated elsewhere.
  }
-- | Implementation of the DLL's @DllGetClassObject@ entry point.
--
-- * Only requests for @IClassFactory@ or @IUnknown@ are honoured; any other
--   IID yields @e_NOINTERFACE@.
-- * A CLSID not found among the registered components yields
--   @cLASS_E_CLASSNOTAVAILABLE@.
-- * On success, a class factory wrapping the component's 'newInstance' is
--   written to @*ppvObject@ and @s_OK@ is returned.
--
-- The out-pointer is cleared before anything else, so every failure path
-- leaves @*ppvObject@ NULL, as the COM contract for @DllGetClassObject@
-- requires.
dllGetClassObject :: ComDllState -> Ptr CLSID -> Ptr (IID a) -> Ptr (Ptr (IUnknown a)) -> IO HRESULT
dllGetClassObject comDll rclsid riid ppvObject = do
  -- Never leave a stale/uninitialised pointer behind on an error return.
  -- Skip the write if the caller passed a NULL out-pointer.
  unless (ppvObject == nullPtr) (poke ppvObject nullPtr)
  iid <- unmarshallIID False (castPtr riid)
  let g = iidToGUID iid
  if not (g == iidToGUID iidIClassFactory || g == iidToGUID iidIUnknown)
    then return e_NOINTERFACE
    else do
      clsid <- unmarshallCLSID False rclsid
      cs <- readIORef (components comDll)
      case lookupCLSID clsid cs of
        Nothing -> return cLASS_E_CLASSNOTAVAILABLE
        Just i -> do
          ip <- createClassFactory (newInstance i (dllPath comDll) (componentFinalise i))
          writeIUnknown False ppvObject ip
          return s_OK
-- | Find the first component whose CLSID has the same GUID as the given one.
lookupCLSID :: CLSID -> [ComponentInfo] -> Maybe ComponentInfo
lookupCLSID clsid = find matches
  where
    target = clsidToGUID clsid
    matches info = clsidToGUID (componentCLSID info) == target
-- | Implementation of @DllCanUnloadNow@: returns s_OK when the lock count is
-- zero and s_FALSE otherwise.
dllCanUnloadNow :: ComDllState -> IO HRESULT
dllCanUnloadNow state = fmap verdict (readIORef (lockCount state))
  where
    verdict :: Int -> HRESULT
    verdict n
      | n == 0    = s_OK
      | otherwise = s_FALSE
-- | @DllRegisterServer@: register every component of this DLL.
dllRegisterServer :: ComDllState -> IO HRESULT
dllRegisterServer st = registerServer True st
-- | @DllUnregisterServer@: unregister every component of this DLL.
dllUnregisterServer :: ComDllState -> IO HRESULT
dllUnregisterServer st = registerServer False st
-- | Shared implementation of @DllRegisterServer@ (@isReg = True@) and
-- @DllUnregisterServer@ (@isReg = False@).
--
-- Components are processed in list order:
--
--   * registration writes the standard registry entries first, then runs
--     the component's own 'registerComponent' hook;
--   * unregistration runs the hook first, then removes the standard
--     entries.
--
-- This runs inside a foreign-exported callback. Any Haskell exception
-- raised along the way (e.g. from 'checkHR' in the registry helpers) is
-- therefore converted into @e_FAIL@ instead of escaping across the FFI
-- boundary. Components after the failing one are not processed.
registerServer :: Bool -> ComDllState -> IO HRESULT
registerServer isReg st = do
  cs <- readIORef (components st)
  (mapM_ regComponent cs >> return s_OK) `catch` onFailure
  where
    path = dllPath st

    regComponent info
      | isReg = do
          stdRegComponent info True path
          registerComponent info info path isReg
      | otherwise = do
          registerComponent info info path isReg
          stdUnRegComponent info True path

    -- Report the failure to the COM caller as an HRESULT.
    onFailure :: SomeException -> IO HRESULT
    onFailure _ = return e_FAIL
-- | @DllUnload@ entry point; currently a no-op.
dllUnload :: ComDllState -> IO ()
dllUnload = const (return ())
-- | Create the DLL state from a module handle and the components to serve.
--
-- The path is obtained via 'getModuleFileName', and the lock count starts
-- at 1.
newComDllState :: Ptr () -> [ComponentInfo] -> IO ComDllState
newComDllState hMod cs =
  liftM3 ComDllState (getModuleFileName hMod) (newIORef cs) (newIORef 1)
-- | Build the DLL state for the given module handle and components. Then
-- export its entry points and package them as a vtable.
createIComDll :: Ptr () -> [ComponentInfo] -> IO (VTable iid_comDllState ComDllState)
createIComDll hMod cs =
  newComDllState hMod cs >>= iComDllEntryPoints >>= createVTable
-- | Create C-callable pointers for the DLL entry points bound to @state@.
--
-- The list order (DllUnload, DllCanUnloadNow, DllRegisterServer,
-- DllUnregisterServer, DllGetClassObject) becomes the slot order of the
-- vtable built by 'createIComDll'. NOTE(review): presumably this must match
-- the layout expected by the C side; do not reorder.
iComDllEntryPoints :: ComDllState -> IO [Ptr ()]
iComDllEntryPoints state =
  sequence
    [ fmap castPtr (export_DllUnload (dllUnload state))
    , fmap castPtr (export_nullaryMeth (dllCanUnloadNow state))
    , fmap castPtr (export_nullaryMeth (dllRegisterServer state))
    , fmap castPtr (export_nullaryMeth (dllUnregisterServer state))
    , fmap castPtr (export_dllGetClassObject (dllGetClassObject state))
    ]
-- | Wrap a Haskell action as a C-callable pointer (used for the DllUnload
-- slot in 'iComDllEntryPoints').
foreign import ccall "wrapper"
  export_DllUnload :: IO () -> IO (Ptr (IO ()))
-- | Wrap an @IO HRESULT@ action as a C-callable pointer. Used for the
-- DllCanUnloadNow, DllRegisterServer and DllUnregisterServer slots.
foreign import ccall "wrapper"
  export_nullaryMeth :: IO HRESULT -> IO (Ptr (IO HRESULT))
-- | Wrap the 'dllGetClassObject' callback as a C-callable pointer.
foreign import ccall "wrapper"
  export_dllGetClassObject :: (Ptr CLSID -> Ptr (IID a) -> Ptr (Ptr (IUnknown a)) -> IO HRESULT)
                           -> IO (Ptr (Ptr CLSID -> Ptr (IID a) -> Ptr (Ptr (IUnknown a)) -> IO HRESULT))
-- | Registry root keys accepted by 'regAddEntry' and 'regRemoveEntry'.
--
-- The 'fromEnum' index of the constructor is what gets passed to the C
-- helpers 'primRegAddEntry' / 'primRegRemoveEntry', so constructor order is
-- significant. NOTE(review): presumably it must match a table on the C
-- side; confirm before reordering or inserting constructors.
data RegHive
  = HKEY_CLASSES_ROOT
  | HKEY_CURRENT_USER
  | HKEY_LOCAL_MACHINE
  | HKEY_USERS
  | HKEY_CURRENT_CONFIG
  deriving ( Eq, Ord, Enum )
-- | Add a registry entry under @hive@ at key @path@, optionally setting a
-- value string. The returned HRESULT is passed to 'checkHR'.
--
-- NOTE(review): the marshalled strings are not freed here; confirm whether
-- 'marshallString' allocations are reclaimed elsewhere.
regAddEntry :: RegHive
            -> String
            -> Maybe String
            -> IO ()
regAddEntry hive path value = do
  pathPtr <- marshallString path
  valuePtr <- marshallMaybe marshallString nullPtr value
  primRegAddEntry (fromEnum hive) pathPtr valuePtr >>= checkHR
-- | Remove a registry entry under @hive@ at key @path@, naming @value@.
-- @removeKey@ is passed to the C helper as 1 (True) or 0 (False). The
-- returned HRESULT is passed to 'checkHR'.
regRemoveEntry :: RegHive
               -> String
               -> String
               -> Bool
               -> IO ()
regRemoveEntry hive path value removeKey = do
  pathPtr <- marshallString path
  valuePtr <- marshallString value
  -- fromEnum on Bool yields 0 for False and 1 for True.
  primRegRemoveEntry (fromEnum hive) pathPtr valuePtr (fromEnum removeKey) >>= checkHR
-- | C helper behind 'regAddEntry'. Arguments: hive index ('fromEnum' of a
-- 'RegHive'), key path, and value string (presumably NULL when no value
-- is given). Returns an HRESULT.
foreign import ccall "primRegAddEntry"
  primRegAddEntry :: Int -> Ptr String -> Ptr String -> IO HRESULT
-- | C helper behind 'regRemoveEntry'. Arguments: hive index, key path,
-- value/sub-key name, and a 0/1 remove-key flag. Returns an HRESULT.
foreign import ccall "primRegRemoveEntry"
  primRegRemoveEntry :: Int -> Ptr String -> Ptr String -> Int -> IO HRESULT
-- | Write the standard HKEY_CLASSES_ROOT entries for a component.
--
-- Entries are written in this order. Optional ones are skipped when the
-- corresponding 'ComponentInfo' field is empty.
--
--   * the CLSID key, holding the component name;
--   * its ProgID and VersionIndependentProgID sub-keys (optional);
--   * its InprocServer32 (@isInProc@) or LocalServer32 sub-key, holding
--     @path@;
--   * type-library loading via 'loadTypeLibEx', when 'componentTLB' is set
--     (best-effort: failures are deliberately ignored);
--   * the CLSID sub-keys of the ProgID / version-independent ProgID
--     (optional);
--   * the ProgID's CurVer sub-key, only when both ProgIDs are set.
stdRegComponent :: ComponentInfo -> Bool -> String -> IO ()
stdRegComponent info isInProc path = do
  addKey clsidPath (componentName info)
  unless (null progid) $ addKey (clsidPath ++ "\\ProgID") progid
  unless (null vprogid) $ addKey (clsidPath ++ "\\VersionIndependentProgID") vprogid
  addKey (clsidPath ++ serverKey) path
  when (componentTLB info) $
    (loadTypeLibEx path True >>= \ tlb -> tlb # release >> return ())
      `catch` ignoreTlbFailure
  unless (null progid) $ addKey (progid ++ "\\CLSID") clsidStr
  unless (null vprogid) $ addKey (vprogid ++ "\\CLSID") clsidStr
  unless (null vprogid || null progid) $ addKey (progid ++ "\\CurVer") vprogid
  where
    clsidStr = show (componentCLSID info)
    clsidPath = "CLSID\\" ++ clsidStr
    progid = componentProgID info
    vprogid = componentVProgID info
    serverKey
      | isInProc = "\\InprocServer32"
      | otherwise = "\\LocalServer32"
    addKey key val = regAddEntry HKEY_CLASSES_ROOT key (Just val)
    -- Type-library handling is best-effort and must not abort registration.
    ignoreTlbFailure :: SomeException -> IO ()
    ignoreTlbFailure _ = return ()
-- | Remove the standard HKEY_CLASSES_ROOT entries written by
-- 'stdRegComponent', via 'regRemoveEntry'.
--
-- Removal order:
--
--   * the server sub-key;
--   * the VersionIndependentProgID and ProgID sub-keys (optional);
--   * the CLSID key itself;
--   * the ProgID's CLSID entries (optional);
--   * its CurVer (only when both ProgIDs are set);
--   * the version-independent ProgID's CLSID entries (optional).
--
-- The path argument is not used.
stdUnRegComponent :: ComponentInfo -> Bool -> String -> IO ()
stdUnRegComponent info isInProc _path = do
  removeFrom clsidPath serverKey True
  unless (null vprogid) $ removeFrom clsidPath "VersionIndependentProgID" True
  unless (null progid) $ removeFrom clsidPath "ProgID" True
  removeFrom "CLSID" clsidStr True
  unless (null progid) $ do
    removeFrom (progid ++ "\\CLSID") clsidStr False
    removeFrom progid "CLSID" True
  unless (null progid || null vprogid) $ removeFrom progid "CurVer" True
  unless (null vprogid) $ do
    removeFrom (vprogid ++ "\\CLSID") clsidStr False
    removeFrom vprogid "CLSID" True
  where
    clsidStr = show (componentCLSID info)
    clsidPath = "CLSID\\" ++ clsidStr
    progid = componentProgID info
    vprogid = componentVProgID info
    serverKey
      | isInProc = "InprocServer32"
      | otherwise = "LocalServer32"
    removeFrom = regRemoveEntry HKEY_CLASSES_ROOT