{-# LANGUAGE LambdaCase #-}

module Opaleye.Internal.Optimize where

import           Prelude hiding (product)

import qualified Opaleye.Internal.PrimQuery as PQ
import           Opaleye.Internal.Helpers   ((.:))

import qualified Data.List.NonEmpty as NEL
import           Data.Semigroup ((<>))

import           Control.Applicative ((<$>), (<*>), pure)
import           Control.Arrow (first)

-- | Apply the structural rewrites of this module to a query.
--
-- Unit factors are stripped from products first ('removeUnit'), and the
-- result then has directly nested products flattened ('mergeProduct').
optimize :: PQ.PrimQuery' a -> PQ.PrimQuery' a
optimize query = mergeProduct (removeUnit query)

-- | Drop 'PQ.Unit' factors from the products visited by the fold.
--
-- Implemented with 'PQ.foldPrimQuery', overriding only the product case of
-- 'PQ.primQueryFoldDefault'.  Factors whose query satisfies 'PQ.isUnit' are
-- filtered out.  If no factor survives, the product keeps a single
-- 'PQ.Unit' factor (built with 'pure', so its first component is 'mempty'),
-- because a product's factor list is a 'NEL.NonEmpty'.  The product's
-- conditions are passed through unchanged.
removeUnit :: PQ.PrimQuery' a -> PQ.PrimQuery' a
removeUnit = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = dropUnits }
  where
    dropUnits factors = case NEL.filter (not . PQ.isUnit . snd) factors of
      []              -> PQ.Product (pure (pure PQ.Unit))
      (kept : others) -> PQ.Product (kept NEL.:| others)

-- | Flatten products that appear directly as factors of another product.
--
-- A factor that is itself a 'PQ.Product' is replaced by that product's own
-- factors.  The outer factor's lateral annotation is prepended to each
-- inner one with '<>' (outer '<>' inner).  The inner product's conditions
-- are appended after the outer product's conditions.  All other factors
-- are kept unchanged.
mergeProduct :: PQ.PrimQuery' a -> PQ.PrimQuery' a
mergeProduct = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = flatten }
  where
    flatten factors outerConds =
      PQ.Product (factors >>= spliceFactor)
                 (outerConds ++ concatMap innerConds (NEL.toList factors))

    -- Replace a nested product by its factors; keep anything else as is.
    spliceFactor = \case
      (lat, PQ.Product subFactors _) -> fmap (first (lat <>)) subFactors
      other                          -> pure other

    -- The conditions carried by a nested product, if this factor is one.
    innerConds = \case
      (_, PQ.Product _ conds) -> conds
      _                       -> []

-- | Eliminate 'PQ.Empty' subqueries, propagating emptiness upwards.
--
-- Returns 'Nothing' when the fold decides the whole query reduces to
-- 'PQ.Empty', and otherwise 'Just' the rebuilt query.  Per constructor:
--
--   * 'PQ.Empty' gives 'Nothing'.  'PQ.Unit', base tables, 'PQ.Values' and
--     relation expressions are always kept.
--   * A product gives 'Nothing' as soon as any one factor does.
--   * Single-child nodes (aggregate, distinct-on/order-by, limit, label,
--     rebind) give 'Nothing' exactly when their child does.
--   * Joins and exists give 'Nothing' when either input does.
--     NOTE(review): for outer joins (and possibly for one of the 'Exists'
--     variants selected by its 'Bool'), an empty input does not make the
--     SQL result empty.  Confirm these cases cannot reach here, or that
--     this is intended.
--   * Set operations are resolved by the local 'binary' helper.
removeEmpty :: PQ.PrimQuery' a -> Maybe (PQ.PrimQuery' b)
removeEmpty = PQ.foldPrimQuery PQ.PrimQueryFold {
    PQ.unit      = return PQ.Unit
  , PQ.empty     = const Nothing
  , PQ.baseTable = return .: PQ.BaseTable
    -- 'traverse sequenceA' first moves each factor's 'Maybe' out of its
    -- (lateral, query) pair.  It then moves the 'Maybe' out of the
    -- 'NonEmpty' list, so a single 'Nothing' factor empties the product.
  , PQ.product   = \factors conds ->
      PQ.Product <$> traverse sequenceA factors <*> pure conds
  , PQ.aggregate = fmap . PQ.Aggregate
  , PQ.distinctOnOrderBy = \mDistinctOns -> fmap . PQ.DistinctOnOrderBy mDistinctOns
  , PQ.limit     = fmap . PQ.Limit
  , PQ.join      = \jt pe pes1 pes2 pq1 pq2 -> PQ.Join jt pe pes1 pes2 <$> pq1 <*> pq2
  , PQ.existsf   = \b pq1 pq2 -> PQ.Exists b <$> pq1 <*> pq2
  , PQ.values    = return .: PQ.Values
  , PQ.binary    = \case
      -- Some unfortunate duplication here
      PQ.Except       -> binary Just            (const Nothing) PQ.Except
      PQ.Union        -> binary Just            Just            PQ.Union
      PQ.Intersect    -> binary (const Nothing) (const Nothing) PQ.Intersect

      PQ.ExceptAll    -> binary Just            (const Nothing) PQ.ExceptAll
      PQ.UnionAll     -> binary Just            Just            PQ.UnionAll
      PQ.IntersectAll -> binary (const Nothing) (const Nothing) PQ.IntersectAll
  , PQ.label     = fmap . PQ.Label
  , PQ.relExpr   = return .: PQ.RelExpr
  , PQ.rebind    = \b -> fmap . PQ.Rebind b
  }
  where -- Combine the emptiness results of the two sides of a set operation:
        --   * both sides empty          -> the result is empty;
        --   * only the second non-empty -> apply n2 to the second side;
        --   * only the first non-empty  -> apply n1 to the first side;
        --   * both non-empty            -> rebuild the operation jj.
        -- n1/n2 are 'Just' when the surviving side is the whole result, and
        -- 'const Nothing' when an empty side forces an empty result.
        --
        -- NOTE(review): for the de-duplicating operators 'PQ.Union' and
        -- 'PQ.Except', SQL still removes duplicates from the surviving side
        -- (@q UNION <empty>@ behaves like @DISTINCT q@).  Here 'Just' returns
        -- that side unchanged.  Confirm this difference is acceptable.
        binary n1 n2 jj = \case
          (Nothing, Nothing)   -> Nothing
          (Nothing, Just pq2)  -> n2 pq2
          (Just pq1, Nothing)  -> n1 pq1
          (Just pq1, Just pq2) -> Just (PQ.Binary jj (pq1, pq2))