{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Arrows #-}

module Opaleye.Internal.Join where

import qualified Opaleye.Internal.HaskellDB.PrimQuery as HPQ
import qualified Opaleye.Internal.PackMap             as PM
import qualified Opaleye.Internal.Tag                 as T
import qualified Opaleye.Internal.Unpackspec          as U
import           Opaleye.Internal.Column (Field_(Column), FieldNullable)
import qualified Opaleye.Internal.QueryArr as Q
import qualified Opaleye.Internal.Operators as Op
import qualified Opaleye.Internal.PrimQuery as PQ
import qualified Opaleye.Internal.PGTypesExternal as T
import qualified Opaleye.Internal.Rebind as Rebind
import qualified Opaleye.SqlTypes as T
import qualified Opaleye.Field as C
import           Opaleye.Field   (Field)
import           Opaleye.Internal.MaybeFields (MaybeFields(MaybeFields),
                                               mfPresent, mfFields)
import qualified Opaleye.Select  as S

import qualified Control.Applicative as A
import qualified Control.Arrow

import           Data.Profunctor (Profunctor, dimap)
import qualified Data.Profunctor.Product as PP
import qualified Data.Profunctor.Product.Default as D

-- | A plain function from a collection of fields to the collection that
-- should be returned for the possibly-missing side of a join.  The
-- 'D.Default' instances below lift a 'Field' to a 'FieldNullable' and leave
-- an existing 'FieldNullable' unchanged.
newtype NullMaker a b = NullMaker (a -> b)

-- | Unwrap a 'NullMaker' and apply it.
toNullable :: NullMaker a b -> a -> b
toNullable (NullMaker f) = f

-- | A non-nullable field is made nullable with 'C.toNullable'.
instance D.Default NullMaker (Field a) (FieldNullable a) where
  def = NullMaker C.toNullable

-- | A field that is already nullable is passed through unchanged.
instance D.Default NullMaker (FieldNullable a) (FieldNullable a) where
  def = NullMaker id

-- | Join two selects with the given 'PQ.JoinType' and @ON@ condition.
--
-- Each side's primitive query is wrapped in a 'PQ.Rebind' that binds every
-- column, via 'extractLeftJoinFields', to a name prefixed @result1_@ (left)
-- or @result2_@ (right), all using one fresh tag.  The @ON@ condition is
-- built from the columns as they were /before/ rebinding, while the caller
-- receives the /rebound/ columns passed through @returnColumnsA@ and
-- @returnColumnsB@ (e.g. to make an outer side nullable).
--
-- NOTE(review): the condition only refers to the pre-rebind names, so this
-- relies on the @True@ flag of 'PQ.Rebind' keeping those names visible —
-- confirm against 'PQ.Rebind'.
joinExplicit :: U.Unpackspec columnsA columnsA
             -> U.Unpackspec columnsB columnsB
             -> (columnsA -> returnedColumnsA)
             -> (columnsB -> returnedColumnsB)
             -> PQ.JoinType
             -> S.Select columnsA
             -> S.Select columnsB
             -> ((columnsA, columnsB) -> Field T.SqlBool)
             -> S.Select (returnedColumnsA, returnedColumnsB)
joinExplicit uA uB returnColumnsA returnColumnsB joinType qA qB cond =
  Q.productQueryArr $ do
    (columnsA, primQueryA) <- Q.runSimpleSelect qA
    (columnsB, primQueryB) <- Q.runSimpleSelect qB

    -- A single tag shared by both sides; the side number (1 or 2) in the
    -- prefix keeps the new names of the two sides apart.
    endTag <- T.fresh

    let (newColumnsA, ljPEsA) =
          PM.run (U.runUnpackspec uA (extractLeftJoinFields 1 endTag) columnsA)
        (newColumnsB, ljPEsB) =
          PM.run (U.runUnpackspec uB (extractLeftJoinFields 2 endTag) columnsB)

        -- The ON clause uses the original (pre-rebind) columns.
        Column cond' = cond (columnsA, columnsB)

        primQueryR = PQ.Join joinType cond'
                       (PQ.NonLateral, PQ.Rebind True ljPEsA primQueryA)
                       (PQ.NonLateral, PQ.Rebind True ljPEsB primQueryB)

    pure ((returnColumnsA newColumnsA, returnColumnsB newColumnsB), primQueryR)


-- | Arrow-style left join: the right-hand select @rq@ is joined on to the
-- current row, using the condition supplied as the arrow's input.
--
-- The columns of @rq@ are first rebound to fresh names with
-- 'Rebind.rebindExplicit'.  The condition is applied to those rebound
-- columns, and the result is those same columns passed through the
-- 'NullMaker'.
leftJoinAExplicit :: U.Unpackspec a a
                  -> NullMaker a nullableA
                  -> S.Select a
                  -> S.SelectArr (a -> Field T.SqlBool) nullableA
leftJoinAExplicit uA nullmaker rq =
  Q.leftJoinQueryArr' $ do
    (newColumnsR, right) <- Q.runSimpleSelect $ proc () -> do
      a <- rq -< ()
      Rebind.rebindExplicit uA -< a
    -- The condition is only known once the arrow receives its input, so
    -- produce a function from it to (returned columns, ON expr, query).
    pure $ \p ->
      let renamedNullable = toNullable nullmaker newColumnsR
          Column cond     = p newColumnsR
      in (renamedNullable, cond, right)

-- | 'optionalRestrictExplicit' with the default 'U.Unpackspec'.
optionalRestrict :: D.Default U.Unpackspec a a
                 => S.Select a
                 -> S.SelectArr (a -> Field T.SqlBool) (MaybeFields a)
optionalRestrict = optionalRestrictExplicit D.def

-- | Left join @q@ on to the current row using the restriction supplied as
-- the arrow's input, and return the right-hand side as 'MaybeFields'.
--
-- Every row of @q@ is paired with a constant @TRUE@ marker field before the
-- join.  The marker is read back as 'mfPresent' via @NOT (marker IS NULL)@.
optionalRestrictExplicit :: U.Unpackspec a a
                         -> S.Select a
                         -> S.SelectArr (a -> Field T.SqlBool) (MaybeFields a)
optionalRestrictExplicit uA q =
  -- The restriction only looks at the payload, not the marker.
  dimap (. snd) toMaybeFields $
    leftJoinAExplicit (PP.p2 (U.unpackspecField, uA))
                      (NullMaker id)
                      (fmap (\x -> (T.sqlBool True, x)) q)
  where
    -- The marker is typed as non-nullable but can be NULL after the join,
    -- so it is reinterpreted as a nullable boolean before the IS NULL test.
    -- The explicit annotation avoids leaving the coerced type ambiguous.
    toMaybeFields (nonNullIfPresent, rest) =
      let present = Op.not (C.isNull (C.unsafeCoerceField nonNullIfPresent
                                        :: FieldNullable T.SqlBool))
      in MaybeFields { mfPresent = present
                     , mfFields  = rest
                     }

-- | An example demonstrating how the functionality of @LEFT JOIN@ can be
-- recovered using 'optionalRestrict'.  Each row of @qL@ supplies the left
-- argument of @cond@.  'optionalRestrict' then restricts @qR@ by the
-- partially applied condition.
leftJoinInTermsOfOptionalRestrict :: D.Default U.Unpackspec fieldsR fieldsR
                                  => S.Select fieldsL
                                  -> S.Select fieldsR
                                  -> ((fieldsL, fieldsR) -> Field T.SqlBool)
                                  -> S.Select (fieldsL, MaybeFields fieldsR)
leftJoinInTermsOfOptionalRestrict qL qR cond = proc () -> do
  fieldsL      <- qL -< ()
  maybeFieldsR <- optionalRestrict qR -< curry cond fieldsL
  Control.Arrow.returnA -< (fieldsL, maybeFieldsR)

-- | Rebind a column using 'PM.extractAttr' with the name prefix
-- @result\<n\>_@.  'joinExplicit' passes @n = 1@ for the left side and
-- @n = 2@ for the right side, so the two sides get distinct prefixes.
extractLeftJoinFields :: Int
                      -> T.Tag
                      -> HPQ.PrimExpr
                      -> PM.PM [(HPQ.Symbol, HPQ.PrimExpr)] HPQ.PrimExpr
extractLeftJoinFields n = PM.extractAttr (concat ["result", show n, "_"])

-- { Boilerplate instances

-- | Post-compose a function with the wrapped one.
instance Functor (NullMaker a) where
  fmap f (NullMaker g) = NullMaker (f . g)

-- | The function ("reader") applicative: both sides see the same input.
instance A.Applicative (NullMaker a) where
  pure x = NullMaker (const x)
  NullMaker f <*> NullMaker x = NullMaker (\input -> f input (x input))

-- | Pre-compose @f@ and post-compose @g@ around the wrapped function.
instance Profunctor NullMaker where
  dimap f g (NullMaker h) = NullMaker (g . h . f)

-- | Defined through the 'A.Applicative' instance, so that product-profunctor
-- combinators (e.g. 'PP.p2') can combine 'NullMaker's.
instance PP.ProductProfunctor NullMaker where
  purePP = A.pure
  (****) = (A.<*>)