-- | Several ways of computing the median of three values with the LLVM DSL.
--
-- For totally ordered inputs all variants pick the same operand.
-- They differ only in how the choice is expressed:
--
-- * nested 'Expr.ifThenElse',
-- * nested 'Expr.select', with and without 'Expr.with' around the comparisons,
-- * 'Expr.min' / 'Expr.max',
-- * a lookup on the packed comparison results.
module LLVM.DSL.Example.Median where

import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp, (==*), (<=*), (&&*), (||*))

import qualified LLVM.Extra.Multi.Vector as MVec
import qualified LLVM.Extra.Multi.Value as MV

import qualified LLVM.Core as LLVM

import Data.Tuple.HT (uncurry3)
import Data.Word (Word8)


-- | Median of three values, written as a decision tree of
-- 'Expr.ifThenElse' expressions.
median3IfThen :: (MV.Comparison a) => Exp a -> Exp a -> Exp a -> Exp a
median3IfThen a b c =
   Expr.ifThenElse (a <=* b) whenAB whenBA
   where
      -- a <= b: b if b <= c, otherwise the larger of a and c
      whenAB = Expr.ifThenElse (b <=* c) b (Expr.ifThenElse (a <=* c) c a)
      -- b < a: a if a <= c, otherwise the larger of b and c
      whenBA = Expr.ifThenElse (a <=* c) a (Expr.ifThenElse (b <=* c) c b)

-- | The same decision tree as 'median3IfThen', with every choice made by
-- 'Expr.select'.
--
-- The comparisons @a <=* c@ and @b <=* c@ appear in both arms.
-- 'median3SelectShared' binds each comparison only once instead.
median3Select ::
   (MV.Comparison a, MV.Select a) => Exp a -> Exp a -> Exp a -> Exp a
median3Select a b c =
   Expr.select (a <=* b) whenAB whenBA
   where
      whenAB = Expr.select (b <=* c) b (Expr.select (a <=* c) c a)
      whenBA = Expr.select (a <=* c) a (Expr.select (b <=* c) c b)

-- | Like 'median3Select', but each of the three comparisons is bound once
-- with 'Expr.with' and then reused by both arms.
median3SelectShared ::
   (MV.Comparison a, MV.Select a) => Exp a -> Exp a -> Exp a -> Exp a
median3SelectShared a b c =
   Expr.with (a <=* b) $ \aLeB ->
   Expr.with (b <=* c) $ \bLeC ->
   Expr.with (a <=* c) $ \aLeC ->
   let whenAB = Expr.select bLeC b (Expr.select aLeC c a)
       whenBA = Expr.select aLeC a (Expr.select bLeC c b)
   in  Expr.select aLeB whenAB whenBA

-- | Median via min/max.
--
-- If @max a b <= c@, then @max a b@ is the median.
-- Otherwise the median is @c@ when @min a b <= c@, and @min a b@ when not.
--
-- NOTE(review): @lo@ and @hi@ are ordinary Haskell bindings, not
-- 'Expr.with', and each is used twice. Confirm whether 'Exp' shares them
-- or emits their code twice.
median3MinMax ::
   (MV.Comparison a, MV.Select a) => Exp a -> Exp a -> Exp a -> Exp a
median3MinMax a b c =
   Expr.select (hi <=* c) hi (Expr.select (lo <=* c) c lo)
   where
      lo = Expr.min a b
      hi = Expr.max a b

-- | Element-wise 'median3MinMax' over three vectors.
--
-- The vectors are zipped into a vector of triples. Then 'MVec.map' applies
-- 'median3MinMax' to each unzipped triple.
median3MinMaxVector ::
   (LLVM.Positive n, MVec.C a) =>
   (MV.Comparison a, MV.Select a) =>
   MVec.T n a -> MVec.T n a -> MVec.T n a ->
   LLVM.CodeGenFunction r (MVec.T n a)
median3MinMaxVector a b c =
   MVec.map medianOfTriple (MVec.zip3 a b c)
   where
      medianOfTriple abc =
         uncurry3 (Expr.unliftM3 median3MinMax) (MV.unzip3 abc)


-- | Short alias for 'MV.T', used in the signatures below.
type MV = MV.T

-- | Median via a lookup on the three comparison results, packed into a
-- 'Word8' mask.
--
-- Bits of the mask:
--
-- * bit 0 (value 1): @a <= b@
-- * bit 1 (value 2): @a <= c@
-- * bit 2 (value 4): @b <= c@
--
-- Selection table:
--
-- * mask 0 (c < b < a) or 7 (a <= b <= c): @b@
-- * mask 1 (c < a <= b) or 6 (b < a <= c): @a@
-- * any other mask: @c@. This covers 3 (a <= c < b) and 4 (b <= c < a),
--   as well as 2 and 5, which a total order cannot produce.
median3Case ::
   (MV.Comparison a, MV.Select a) =>
   MV a -> MV a -> MV a -> LLVM.CodeGenFunction r (MV a)
median3Case a b c = do
   aLeB <- MV.cmp LLVM.CmpLE a b
   aLeC <- MV.cmp LLVM.CmpLE a c
   bLeC <- MV.cmp LLVM.CmpLE b c
   -- monomorphic binding; its element type is fixed to Word8 by the
   -- annotation on 'code' below
   let flag = MV.fromInteger'
   bit0 <- MV.select aLeB (flag 1) (flag 0)
   bit1 <- MV.select aLeC (flag 2) (flag 0)
   bit2 <- MV.select bLeC (flag 4) (flag 0)
   upperBits <- MV.or bit1 bit2
   code <- MV.or bit0 upperBits
   let codeE = Expr.lift0 (code :: MV Word8)
   pickB <- Expr.unExp (codeE ==* 0 ||* codeE ==* 7)
   pickA <- Expr.unExp (codeE ==* 1 ||* codeE ==* 6)
   bOrC <- MV.select pickB b c
   MV.select pickA a bOrC

-- | The same selection table as 'median3Case', but without a packed mask.
--
-- Each comparison result is turned into a 0/1 'Word8'. Each table entry is
-- then matched by comparing all three digits individually.
median3CaseVec ::
   (MV.Comparison a, MV.Select a) =>
   MV a -> MV a -> MV a -> LLVM.CodeGenFunction r (MV a)
median3CaseVec a b c = do
   aLeB <- MV.cmp LLVM.CmpLE a b
   aLeC <- MV.cmp LLVM.CmpLE a c
   bLeC <- MV.cmp LLVM.CmpLE b c
   let asDigit cond = Expr.select (Expr.lift0 cond) 1 (0 :: Exp Word8)
       matches ab ac bc =
          asDigit aLeB ==* ab &&*
          asDigit aLeC ==* ac &&*
          asDigit bLeC ==* bc
   -- digits are (a<=b, a<=c, b<=c)
   pickB <- Expr.unExp (matches 0 0 0 ||* matches 1 1 1)
   pickA <- Expr.unExp (matches 1 0 0 ||* matches 0 1 1)
   bOrC <- MV.select pickB b c
   MV.select pickA a bOrC