{-# LANGUAGE LambdaCase #-}
module Opaleye.Internal.Optimize where
import Prelude hiding (product)
import qualified Opaleye.Internal.PrimQuery as PQ
import Opaleye.Internal.Helpers ((.:))
import qualified Data.List.NonEmpty as NEL
import Data.Semigroup ((<>))
import Control.Applicative ((<$>), (<*>), pure)
import Control.Arrow (first)
-- | Run the product-simplification passes over a query tree.
--
-- 'removeUnit' is applied first, then 'mergeProduct' is applied to its
-- result.
optimize :: PQ.PrimQuery' a -> PQ.PrimQuery' a
optimize query = mergeProduct (removeUnit query)
-- | Drop, from every product in the query tree, the factors whose subquery
-- satisfies 'PQ.isUnit'.
--
-- Products are non-empty, so if every factor would be dropped a single
-- @(mempty, PQ.Unit)@ factor is kept instead. The product's conditions are
-- passed through unchanged.
removeUnit :: PQ.PrimQuery' a -> PQ.PrimQuery' a
removeUnit = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = dropUnits }
  where
    -- Partially applied: the product's conditions become the remaining
    -- argument of 'PQ.Product'.
    dropUnits factors = PQ.Product (maybe soleUnit id (NEL.nonEmpty kept))
      where
        kept     = NEL.filter (not . PQ.isUnit . snd) factors
        -- 'pure' on the pair supplies 'mempty' as the lateral flag.
        soleUnit = pure (pure PQ.Unit)
-- | Flatten products that appear directly as factors of another product.
--
-- A factor that is itself a 'PQ.Product' is replaced by that product's own
-- factors. Each spliced factor's lateral flag is combined ('<>') with the
-- flag the nested product carried. The nested product's conditions are
-- appended after the outer product's conditions. Factors reach 'flatten'
-- already processed by the fold, so one level of splicing per node suffices.
mergeProduct :: PQ.PrimQuery' a -> PQ.PrimQuery' a
mergeProduct = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = flatten }
  where
    flatten factors conds =
      PQ.Product (factors >>= splice)
                 (conds ++ concatMap liftConds (NEL.toList factors))

    -- Replace a nested product by its own factors, prefixing each factor's
    -- lateral flag with the flag of the nested product.
    splice (lat, PQ.Product inner _) = fmap (first (lat <>)) inner
    splice factor                    = pure factor

    -- The conditions a nested product contributes to its parent.
    liftConds (_, PQ.Product _ innerConds) = innerConds
    liftConds _                            = []
-- | Prune subqueries that are statically known to be empty ('PQ.Empty').
--
-- Returns 'Nothing' when the query as a whole is empty. Otherwise it returns
-- @Just@ the query with the empty branches removed.
--
-- Emptiness propagates upwards through the fold: a node with a 'Nothing'
-- child becomes 'Nothing' itself. The exception is set operations
-- ('PQ.binary'), where @binary@ below decides which operand, if any,
-- survives.
removeEmpty :: PQ.PrimQuery' a -> Maybe (PQ.PrimQuery' b)
removeEmpty = PQ.foldPrimQuery PQ.PrimQueryFold
  { PQ.unit      = return PQ.Unit
  , PQ.empty     = const Nothing
  , PQ.baseTable = return .: PQ.BaseTable
    -- A product is empty if any factor is. Each factor is a
    -- (lateral flag, Maybe subquery) pair:
    --   * 'sequenceA' moves the 'Maybe' out of the pair;
    --   * 'traverse' makes the whole product 'Nothing' as soon as one
    --     factor is.
    -- The product's conditions are kept unchanged.
  , PQ.product   = \factors conds ->
      PQ.Product <$> traverse sequenceA factors <*> pure conds
  , PQ.aggregate = fmap . PQ.Aggregate
  , PQ.distinctOnOrderBy = \mDistinctOns -> fmap . PQ.DistinctOnOrderBy mDistinctOns
  , PQ.limit     = fmap . PQ.Limit
    -- NOTE(review): an empty operand empties the whole join. If @jt@ can
    -- denote an outer join, SQL would still return the (NULL-padded) rows of
    -- the other side. Confirm this pruning is intended for every join type.
  , PQ.join      = \jt pe pes1 pes2 pq1 pq2 -> PQ.Join jt pe pes1 pes2 <$> pq1 <*> pq2
    -- NOTE(review): likewise, if @b@ can select NOT EXISTS, an empty second
    -- operand should leave @pq1@ intact rather than empty the result.
    -- Confirm what @b@ encodes.
  , PQ.existsf   = \b pq1 pq2 -> PQ.Exists b <$> pq1 <*> pq2
  , PQ.values    = return .: PQ.Values
    -- Arguments to @binary@: the result when only the left operand survives,
    -- then the result when only the right operand survives.
    -- NOTE(review): the non-ALL variants (UNION, EXCEPT) also de-duplicate in
    -- SQL, so replacing e.g. @q UNION <empty>@ by plain @q@ may let duplicate
    -- rows through. Confirm this is acceptable.
  , PQ.binary    = \case
      PQ.Except       -> binary Just            (const Nothing) PQ.Except
      PQ.Union        -> binary Just            Just            PQ.Union
      PQ.Intersect    -> binary (const Nothing) (const Nothing) PQ.Intersect
      PQ.ExceptAll    -> binary Just            (const Nothing) PQ.ExceptAll
      PQ.UnionAll     -> binary Just            Just            PQ.UnionAll
      PQ.IntersectAll -> binary (const Nothing) (const Nothing) PQ.IntersectAll
  , PQ.label     = fmap . PQ.Label
  , PQ.relExpr   = return .: PQ.RelExpr
  , PQ.rebind    = \b -> fmap . PQ.Rebind b
  }
  where
    -- Combine the pruned operands of the set operation @jj@:
    --   * both operands empty          -> the result is empty;
    --   * only the right one survives  -> @n2@ decides the result;
    --   * only the left one survives   -> @n1@ decides the result;
    --   * both survive                 -> rebuild the 'PQ.Binary' node.
    binary n1 n2 jj = \case
      (Nothing, Nothing)   -> Nothing
      (Nothing, Just pq2)  -> n2 pq2
      (Just pq1, Nothing)  -> n1 pq1
      (Just pq1, Just pq2) -> Just (PQ.Binary jj (pq1, pq2))