{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Data.Array.Accelerate.Control.Lens.Lift ( liftLens )
where
import Data.Array.Accelerate
#if defined(MIN_VERSION_accelerate)
#if !MIN_VERSION_accelerate(0,16,0)
-- Orphan Unlift instances, emitted only when building against accelerate
-- older than 0.16 (presumably because later versions ship them -- confirm).
-- The outer check keeps the version test from being evaluated when the
-- MIN_VERSION_accelerate macro is not defined; in that case nothing is
-- emitted here.

-- | Unlifting an embedded scalar expression is the identity.
instance Unlift Exp (Exp e) where
  unlift expr = expr

-- | Unlifting an embedded array computation is the identity.
instance Unlift Acc (Acc a) where
  unlift acc = acc
#endif
#endif
-- | Run a lens of the shape @(a -> f b) -> s -> f t@, written against plain
-- values, on embedded values of some @box@ (such as @Exp@ or @Acc@).
--
-- The embedded source is unlifted to a plain @s@. The lens is then run with a
-- focus function that re-embeds each plain focus before handing it to the
-- caller action (see @fsink1@). Finally, the rebuilt plain @t@ is lifted back
-- into @box@ inside @f@.
--
-- The result type only needs @Lift box t@, because it is lifted and never
-- unlifted. This is weaker than the former @Unlift box t@ (in accelerate,
-- Lift is a superclass of Unlift), so every existing caller still
-- type-checks, and result types that only have a Lift instance are also
-- accepted.
liftLens
  :: (Functor f, Unlift box s, Lift box t, Unlift box b, Lift box a)
  => ((a -> f b) -> s -> f t)              -- ^ lens over plain values
  -> (box (Plain a) -> f (box (Plain b)))  -- ^ action on the embedded focus
  -> box (Plain s)                         -- ^ embedded source structure
  -> f (box (Plain t))                     -- ^ embedded rebuilt structure
liftLens l f x = lift `fmap` l (fsink1 f) (unlift x)
-- | Turn an action on an embedded focus into an action on a plain focus.
-- The result has the shape that the plain lens given to @liftLens@ expects.
--
-- The plain focus is lifted into @box@ and passed to the action. The
-- embedded value the action yields (inside @f@) is unlifted back to a
-- plain @b@.
fsink1
  :: (Functor f, Unlift box b, Lift box a)
  => (box (Plain a) -> f (box (Plain b)))  -- ^ action on the embedded focus
  -> a                                     -- ^ plain focus
  -> f b                                   -- ^ plain result inside @f@
fsink1 f focus = unlift `fmap` f (lift focus)