module Happstack.State.Util
(
getRandom, getRandomR,
inferRecordUpdaters
) where
import Control.Concurrent.STM
import Control.Monad.State
import System.Random
import Happstack.State.Monad
import Happstack.State.Types
import Data.Char(toUpper)
import Language.Haskell.TH
-- | Draw a value with 'random' from the generator held in the 'TVar'
-- selected by @sel evRandoms@, then store the advanced generator back.
--
-- The advanced generator is forced (to weak head normal form) with '$!'
-- before it is written.  Without this, each call would store a lazy
-- @snd (random g)@ thunk, and calls whose results are never inspected
-- could accumulate a chain of unevaluated generators inside the 'TVar'.
getRandom :: Random a => AnyEv a
getRandom = do
  ref <- sel evRandoms
  gen <- liftSTM $ readTVar ref
  let (x, gen') = random gen
  liftSTM $ writeTVar ref $! gen'
  return x
-- | Draw a value in the given range with 'randomR' from the generator held
-- in the 'TVar' selected by @sel evRandoms@, then store the advanced
-- generator back.
--
-- As in 'getRandom', the advanced generator is forced (to weak head normal
-- form) with '$!' before it is written, so the 'TVar' does not accumulate
-- chains of unevaluated generator thunks.
getRandomR :: Random a => (a,a) -> AnyEv a
getRandomR z = do
  ref <- sel evRandoms
  gen <- liftSTM $ readTVar ref
  let (x, gen') = randomR z gen
  liftSTM $ writeTVar ref $! gen'
  return x
-- | Template Haskell generator: for every field @foo@ of the type named by
-- the argument, emit two top-level value bindings:
--
-- * @a_foo@, bound to the updater expression that 'updFuns' builds for
--   that field;
-- * @withFoo@, bound to @localState foo a_foo@, i.e. 'localState' applied
--   to the field selector and the generated updater.
--
-- Fails in 'Q' if the name does not reify to a type constructor
-- ('nameToDec') or is not a single-constructor data/newtype declaration
-- ('decToSimpleRecord'). Calls 'error' if that constructor is not written
-- in record syntax ('fieldNames', 'updFuns', 'selFuns').
inferRecordUpdaters :: Name -> Q [Dec]
inferRecordUpdaters typeName = do
  con <- decToSimpleRecord =<< nameToDec typeName
  perField <- sequence $
    zipWith3 fieldDecs (fieldNames con) (updFuns con) (selFuns con)
  return (concat perField)
  where
    -- Build the @a_<field>@ and @with<Field>@ bindings for one field.
    fieldDecs field upd sel = do
      let base        = nameBase field
          updaterName = mkName ("a_" ++ base)
          withName    = mkName ("with" ++ capitalise base)
      updDec  <- updaterName `sdef` upd
      withDec <- withName `sdef`
                   (varE 'localState `appE` sel `appE` varE updaterName)
      return [updDec, withDec]
    -- Upper-case the first character; total, unlike head/tail.
    capitalise (ch:rest) = toUpper ch : rest
    capitalise []        = []
-- | Extract the sole constructor of a @data@ or @newtype@ declaration.
--
-- Fails in 'Q' for a @data@ declaration with zero or several constructors,
-- and for any declaration that is neither @data@ nor @newtype@. This does
-- not check that the constructor uses record syntax.
--
-- NOTE(review): the patterns assume the 5-field 'DataD'/'NewtypeD' of the
-- template-haskell version this module targets; newer template-haskell
-- (>= 2.11) adds a kind field, so adjust these if upgrading.
decToSimpleRecord :: Dec -> Q Con
decToSimpleRecord (DataD _ _ _ [con] _) = return con
decToSimpleRecord (DataD _ n _ [] _) =
  -- An empty data type must not be reported as "multiple constructors".
  fail ("Not a simple record (has no constructors): " ++ show n)
decToSimpleRecord (DataD _ n _ _ _) =
  fail ("Not a simple record (has multiple constructors): " ++ show n)
decToSimpleRecord (NewtypeD _ _ _ con _) = return con
decToSimpleRecord x = fail ("Wanted a simple record, got: " ++ show x)
-- | Reify a name and return its declaration, provided it reifies to a
-- type constructor ('TyConI'); any other kind of 'Info' fails in 'Q'.
nameToDec :: Name -> Q Dec
nameToDec ty = do
  info <- reify ty
  case info of
    TyConI dec -> return dec
    _          -> fail "nameToDec: expected TyCon"
-- | One selector expression (a variable reference to the field name) per
-- field of a record constructor, in declaration order. Calls 'error' for
-- any constructor that is not a 'RecC'.
selFuns :: Con -> [ExpQ]
selFuns (RecC _ fields) = map (\(fieldName, _, _) -> varE fieldName) fields
selFuns _ = error "Constructors other than RecC not handled in selFuns"
-- | One updater expression per field of a record constructor, in
-- declaration order. For a field @f@ the expression is
-- @\\x y -> y { f = x }@: it takes the new field value first and the record
-- second. Calls 'error' for any constructor that is not a 'RecC'.
--
-- The lambda binders come from 'newName', which keeps them hygienic and
-- avoids the partial pattern binding @[x,y] = ...@ used previously.
updFuns :: Con -> [ExpQ]
updFuns (RecC _ ts) = [ upd n | (n,_,_) <- ts ]
  where
    upd field = do
      newVal <- newName "x"
      record <- newName "y"
      lamE [varP newVal, varP record]
           (recUpdE (varE record) [return (field, VarE newVal)])
updFuns _ = error "Constructors other than RecC not handled in updFuns"
-- | The field names of a record constructor, in declaration order. Calls
-- 'error' for any constructor that is not a 'RecC'.
fieldNames :: Con -> [Name]
fieldNames (RecC _ fields) = map (\(fieldName, _, _) -> fieldName) fields
fieldNames _ = error "Constructors other than RecC not handled in fieldNames"
-- | A simple value declaration @name = body@, with no @where@ bindings.
sdef :: Name -> ExpQ -> DecQ
sdef name bodyQ = do
  body <- bodyQ
  return (ValD (VarP name) (NormalB body) [])