{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE CPP #-}

-- | Building blocks for common table expressions: the 'With' monad, in which
-- queries are registered as named CTEs via 'selecting', and 'ReusableQ', the
-- handle used to refer back to a registered CTE via 'reuse'.
module Database.Beam.Query.CTE where

import Database.Beam.Backend.SQL
import Database.Beam.Query.Internal
import Database.Beam.Query.Types

import Control.Monad.Free.Church
import Control.Monad.Writer hiding ((<>))
import Control.Monad.State.Strict

import Data.Text (Text)
import Data.String
import Data.Proxy (Proxy(Proxy))
#if !MIN_VERSION_base(4, 11, 0)
import Data.Semigroup
#endif

import Unsafe.Coerce

-- | Flag accumulated while running a 'With' computation.
--
-- 'Recursive' can only be constructed when the backend's select syntax has an
-- 'IsSql99RecursiveCommonTableExpressionSelectSyntax' instance; matching on it
-- brings that instance back into scope.
data Recursiveness be where
  Nonrecursive :: Recursiveness be
  Recursive
    :: IsSql99RecursiveCommonTableExpressionSelectSyntax (BeamSqlBackendSelectSyntax be)
    => Recursiveness be

-- | 'Recursive' is absorbing: combining it with anything yields 'Recursive'.
-- Only two 'Nonrecursive' values combine to 'Nonrecursive'.
instance Semigroup (Recursiveness be) where
  Recursive <> _ = Recursive
  _ <> Recursive = Recursive
  _ <> _ = Nonrecursive

-- | 'Nonrecursive' is the identity element.
instance Monoid (Recursiveness be) where
  mempty = Nonrecursive
  -- Kept explicit (and canonical) because on base < 4.11 'Semigroup' is not a
  -- superclass of 'Monoid', so there is no default to fall back on.
  mappend = (<>)

-- | Monad in which common table expressions are declared.
--
-- The writer layer collects a 'Recursiveness' flag together with the list of
-- CTE syntaxes emitted so far; the 'Int' state is the counter 'selecting' uses
-- to generate table names.
newtype With be (db :: (* -> *) -> *) a
  = With
  { runWith
      :: WriterT (Recursiveness be, [ BeamSql99BackendCTESyntax be ])
                 (State Int) a
  }
  deriving (Monad, Applicative, Functor)

-- | Using 'mfix' records 'Recursive' in the writer output before delegating
-- to the underlying monad's 'mfix'. The instance is only available for
-- backends whose select syntax supports recursive CTEs.
instance IsSql99RecursiveCommonTableExpressionSelectSyntax (BeamSqlBackendSelectSyntax be)
    => MonadFix (With be db) where
  mfix f = With (tell (Recursive, mempty) >> mfix (runWith . f))

-- | Uninhabited type used as the scope parameter of queries passed to
-- 'selecting' and as the source scope when rewriting threads in 'ReusableQ'.
data QAnyScope

-- | A query that can be instantiated at any scope @s@; see 'reuse'.
data ReusableQ be db res where
  ReusableQ
    :: Proxy res
    -> (forall s. Proxy s -> Q be db s (WithRewrittenThread QAnyScope s res))
    -> ReusableQ be db res

-- | Build a 'ReusableQ' that reads from the table named @tblNm@, projecting
-- the fields produced by 'mkFieldNames' for @res@ and rewriting the result's
-- thread from 'QAnyScope' to the requested scope.
reusableForCTE
  :: forall be res db
   . ( ThreadRewritable QAnyScope res
     , Projectible be res
     , BeamSqlBackend be )
  => Text -> ReusableQ be db res
reusableForCTE tblNm =
  ReusableQ (Proxy @res)
            (\proxyS ->
               Q $ liftF (QAll (\_ -> fromTable (tableNamed (tableName Nothing tblNm)) . Just . (, Nothing))
                               (\tblNm' -> fst $ mkFieldNames @be @res (qualifiedField tblNm'))
                               (\_ -> Nothing)
                               (rewriteThread @QAnyScope @res proxyS . snd)))

-- | Register a query as a common table expression and return a handle to it.
--
-- The CTE is named @cte@ followed by the current counter value, after which
-- the counter is incremented. The emitted entry is tagged 'Nonrecursive'.
selecting
  :: forall res be db
   . ( BeamSql99CommonTableExpressionBackend be, HasQBuilder be
     , Projectible be res
     , ThreadRewritable QAnyScope res )
  => Q be db QAnyScope res -> With be db (ReusableQ be db res)
selecting q =
  With $ do
    cteId <- get
    put (cteId + 1)

    let tblNm = fromString ("cte" ++ show cteId)

        -- The pattern signature pins the projected type to @res@; only the
        -- field names are needed here.
        (_ :: res, fieldNames) = mkFieldNames @be (qualifiedField tblNm)

    -- NOTE(review): @tblNm <> "_"@ is presumably a table-name prefix for the
    -- subquery -- confirm against 'buildSqlQuery'.
    tell (Nonrecursive, [ cteSubquerySyntax tblNm fieldNames (buildSqlQuery (tblNm <> "_") q) ])

    pure (reusableForCTE tblNm)

-- | Change the scope parameter of a 'QM' without touching the value.
--
-- NOTE(review): implemented with 'unsafeCoerce'; this is only sound if the
-- scope parameter @s@ does not affect the runtime representation of 'QM' --
-- confirm against the definition of 'QM'.
rescopeQ :: QM be db s res -> QM be db s' res
rescopeQ = unsafeCoerce

-- | Instantiate a 'ReusableQ' at the scope @s@.
reuse
  :: forall s be db res
   . ReusableQ be db res -> Q be db s (WithRewrittenThread QAnyScope s res)
reuse (ReusableQ _ q) = q (Proxy @s)