{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}

-- | A small typed layer over sqlite-simple: tables and columns carry phantom
-- types, predicates are built with operators, and SQL text is generated here.
module General.Database(
    Pred, (%==), (%==%), (%>), (%<), (%/=), (%&&), nullP, likeP,
    orderDesc, orderAsc, distinct, limit,
    Upd(..),
    TypeField(..),
    Table, table, Column, column, rowid, norowid,
    sqlInsert, sqlUpdate, sqlSelect, sqlDelete, sqlEnsureTable, sqlUnsafe
    ) where

import Data.List.Extra
import Data.String
import Data.Maybe
import Data.Time.Clock
import Data.Tuple.Extra
import Database.SQLite.Simple hiding ((:=))
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.ToField


-- | Maps a (tuple of) 'Column' type(s) to the tuple of the Haskell value
-- types those columns hold; a single column maps to 'Only'.
type family Uncolumns cs
type instance Uncolumns () = ()
type instance Uncolumns (Column a) = Only a
type instance Uncolumns (Only (Column a)) = Only a
type instance Uncolumns (Column a, Column b) = (a, b)
type instance Uncolumns (Column a, Column b, Column c) = (a, b, c)
type instance Uncolumns (Column a, Column b, Column c, Column d) = (a, b, c, d)
type instance Uncolumns (Column a, Column b, Column c, Column d, Column e) = (a, b, c, d, e)
type instance Uncolumns (Column a, Column b, Column c, Column d, Column e, Column f) = (a, b, c, d, e, f)
type instance Uncolumns (Column a, Column b, Column c, Column d, Column e, Column f, Column g) = (a, b, c, d, e, f, g)
type instance Uncolumns (Column a, Column b, Column c, Column d, Column e, Column f, Column g, Column h) = (a, b, c, d, e, f, g, h)
type instance Uncolumns (Column a, Column b, Column c, Column d, Column e, Column f, Column g, Column h, Column i) = (a, b, c, d, e, f, g, h, i)

-- | A table, tagged with the type of its rowid and the tuple type of a row.
data Table rowid cs = Table {tblName :: String, tblKeys :: [Column_], tblCols :: [Column_]}

-- | A column of some table; the phantom @c@ is the Haskell type stored in it.
-- 'colSqlType' is the SQL type declaration used by 'sqlEnsureTable'
-- (empty for the 'rowid' / 'norowid' pseudo-columns).
data Column c = Column {colTable :: String, colName :: String, colSqlType :: String}
    deriving (Eq,Show)

-- | A column with its value type erased, so columns can be stored in lists.
type Column_ = Column ()

-- | Forget the value type of a column.
column_ :: Column c -> Column_
column_ Column{..} = Column{..}


-- | The SQL type declaration used for columns holding a given Haskell type.
-- The argument is never inspected ('column' passes 'undefined').
class TypeField field where
    typeField :: field -> String

instance TypeField String where typeField _ = "TEXT NOT NULL"
instance TypeField Int where typeField _ = "INTEGER NOT NULL"
instance TypeField Double where typeField _ = "REAL NOT NULL"
instance TypeField UTCTime where typeField _ = "TEXT NOT NULL"
instance TypeField Bool where typeField _ = "INTEGER NOT NULL"

-- | A 'Maybe' column is the underlying type with its @NOT NULL@ suffix
-- removed; it is an error if the underlying type has no such suffix.
instance TypeField a => TypeField (Maybe a) where
    typeField x
        | Just nullable <- stripSuffix " NOT NULL" inner = nullable
        | otherwise = error "Can't remove the NULL constraint"
        -- fromJust is lazy here: the underlying instances ignore their argument
        where inner = typeField $ fromJust x


-- | Tuples of columns that can be flattened into a list of erased columns.
class Columns cs where
    columns :: cs -> [Column_]

instance Columns () where
    columns () = []
instance Columns (Column c1) where
    columns c1 = [column_ c1]
instance Columns (Only (Column c1)) where
    columns (Only c1) = [column_ c1]
instance Columns (Column c1, Column c2) where
    columns (c1, c2) = [column_ c1, column_ c2]
instance Columns (Column c1, Column c2, Column c3) where
    columns (c1, c2, c3) = [column_ c1, column_ c2, column_ c3]
instance Columns (Column c1, Column c2, Column c3, Column c4) where
    columns (c1, c2, c3, c4) = [column_ c1, column_ c2, column_ c3, column_ c4]
instance Columns (Column c1, Column c2, Column c3, Column c4, Column c5) where
    columns (c1, c2, c3, c4, c5) = [column_ c1, column_ c2, column_ c3, column_ c4, column_ c5]
instance Columns (Column c1, Column c2, Column c3, Column c4, Column c5, Column c6) where
    columns (c1, c2, c3, c4, c5, c6) = [column_ c1, column_ c2, column_ c3, column_ c4, column_ c5, column_ c6]
instance Columns (Column c1, Column c2, Column c3, Column c4, Column c5, Column c6, Column c7) where
    columns (c1, c2, c3, c4, c5, c6, c7) = [column_ c1, column_ c2, column_ c3, column_ c4, column_ c5, column_ c6, column_ c7]
instance Columns (Column c1, Column c2, Column c3, Column c4, Column c5, Column c6, Column c7, Column c8) where
    columns (c1, c2, c3, c4, c5, c6, c7, c8) = [column_ c1, column_ c2, column_ c3, column_ c4, column_ c5, column_ c6, column_ c7, column_ c8]
instance Columns (Column c1, Column c2, Column c3, Column c4, Column c5, Column c6, Column c7, Column c8, Column c9) where
    columns (c1, c2, c3, c4, c5, c6, c7, c8, c9) = [column_ c1, column_ c2, column_ c3, column_ c4, column_ c5, column_ c6, column_ c7, column_ c8, column_ c9]


-- | Define a table from its name, rowid column, primary-key columns and all
-- columns. The key/column lists are validated lazily: errors are raised when
-- 'tblKeys' / 'tblCols' are demanded, if any column belongs to another table,
-- a key is not among the columns, or the rowid column is not named @rowid@
-- (or empty, for 'norowid').
table :: (Columns keys, Columns cols) => String -> Column rowid -> keys -> cols -> Table rowid (Uncolumns cols)
-- important to produce name before looking at columns
-- (columns are usually defined via 'column' on this very table, so forcing them
-- before 'tblName' is available would loop)
table name rowid (columns -> keys) (columns -> cols) = Table name (check keys) (check cols)
    where
        check x
            | nubOrd (map colTable $ keys ++ cols) /= [name] = error "Column with the wrong table"
            | not $ null $ map colName keys \\ map colName cols = error "Key column which is not one of the normal columns"
            | colName rowid `notElem` ["","rowid"] = error "Rowid column must have name rowid"
            | otherwise = x

-- | A named column of a table; its SQL type comes from 'TypeField'.
column :: forall c rowid cs . TypeField c => Table rowid cs -> String -> Column c
column tbl row = Column (tblName tbl) row (typeField (undefined :: c))

-- | The implicit @rowid@ pseudo-column of a table.
rowid :: Table rowid cs -> Column rowid
rowid tbl = Column (tblName tbl) "rowid" ""

-- | Placeholder rowid column for tables defined without a rowid.
norowid :: Column ()
norowid = Column "" "" ""


-- | Insert a full row, then return @last_insert_rowid()@.
sqlInsert :: (ToRow cs, FromField rowid) => Connection -> Table rowid cs -> cs -> IO rowid
sqlInsert conn tbl val = do
    let vs = toRow val
    -- FIXME: Should combine the last_insert_rowid with the INSERT INTO
    let str = "INSERT INTO " ++ tblName tbl ++ " VALUES (" ++ intercalate "," (replicate (length vs) "?") ++ ")"
    execute conn (fromString str) vs
    [Only row] <- query_ conn (fromString "SELECT last_insert_rowid()")
    return row

-- | Set the given columns on all rows matching the predicates. All updated
-- and predicate columns must belong to a single table; fails otherwise, and
-- fails if no columns are updated.
sqlUpdate :: Connection -> [Upd] -> [Pred] -> IO ()
sqlUpdate conn upd pred = do
    let (updCs, updVs) = unzip $ map unupdate upd
    let (prdStr, _, prdCs, prdVs) = unpred pred
    let tbl = nubOrd $ map colTable $ updCs ++ prdCs
    case tbl of
        _ | null upd -> fail "Must update at least one column"
        [t] -> do
            let str = "UPDATE " ++ t ++ " SET " ++ intercalate ", " (map ((++ "=?") . colName) updCs) ++ " WHERE " ++ prdStr
            execute conn (fromString str) (updVs ++ prdVs)
        _ -> fail "Must update all in the same table"

-- | Delete the rows of a table matching the predicates; fails if the
-- predicates mention columns of any other table.
sqlDelete :: Connection -> Table rowid cs -> [Pred] -> IO ()
sqlDelete conn tbl pred = do
    let (prdStr, _, prdCs, prdVs) = unpred pred
    case nubOrd $ tblName tbl : map colTable prdCs of
        [t] -> do
            let str = "DELETE FROM " ++ t ++ " WHERE " ++ prdStr
            execute conn (fromString str) prdVs
        ts -> fail $ "sqlDelete, can only delete from one table but you are touching: " ++ unwords ts

-- | Select the given columns from every table they (or the predicates)
-- mention, filtered by the predicates.
sqlSelect :: (FromRow (Uncolumns cs), Columns cs) => Connection -> cs -> [Pred] -> IO [Uncolumns cs]
sqlSelect conn cols pred = do
    let outCs = columns cols
    let (prdStr, prdDs, prdCs, prdVs) = unpred pred
    -- SQL only allows DISTINCT directly after SELECT (where it applies to the
    -- whole result row), so emit it once rather than in front of each column
    let distinctStr = if any (`elem` prdDs) outCs then "DISTINCT " else ""
    let str = "SELECT " ++ distinctStr ++ intercalate ", " [colTable ++ "." ++ colName | Column{..} <- outCs] ++ " " ++
              "FROM " ++ intercalate ", " (nubOrd $ map colTable $ outCs ++ prdCs) ++
              " WHERE " ++ prdStr
    query conn (fromString str) prdVs

-- | Create the table if it does not exist. If a table of that name exists,
-- its stored @CREATE TABLE@ text must match exactly what would be generated,
-- otherwise this raises an error showing both.
sqlEnsureTable :: Connection -> Table rowid cs -> IO ()
sqlEnsureTable conn Table{..} = do
    let fields = intercalate ", " $
            [colName ++ " " ++ colSqlType | Column{..} <- tblCols] ++
            ["PRIMARY KEY (" ++ intercalate ", " (map colName tblKeys) ++ ")" | not $ null tblKeys]
    let str = "CREATE TABLE " ++ tblName ++ " (" ++ fields ++ ")"
    existing <- query conn (fromString "SELECT sql FROM sqlite_master WHERE type = ? AND name = ?") ("table", tblName)
    case existing of
        [Only s] | str == s -> return ()
        [] -> execute_ conn $ fromString str
        _ -> error $ "Trying to ensure table " ++ tblName ++ " but mismatch" ++
                     "\nCreating:\n" ++ str ++ "\nGot:\n" ++ unlines (map fromOnly existing)

-- | Escape hatch: run a raw SQL query with parameters.
sqlUnsafe :: (ToRow q, FromRow r) => Connection -> String -> q -> IO [r]
sqlUnsafe conn str q = query conn (fromString str) q


-- | A column assignment for 'sqlUpdate'.
data Upd = forall a . ToField a => Column a := a

-- | Split an assignment into its erased column and SQL value.
unupdate :: Upd -> (Column_, SQLData)
unupdate (c := v) = (column_ c, toField v)


-- | A query predicate. 'PDistinct', 'POrder' and 'PLimit' are modifiers that
-- 'unpred' only recognises at the top level of the predicate list.
data Pred
    = PNull Column_
    | PNotNull Column_
    | PEq Column_ SQLData
    | PNEq Column_ SQLData
    | PGt Column_ SQLData
    | PLt Column_ SQLData
    | PEqP Column_ Column_
    | PLike Column_ SQLData
    | PAnd [Pred]
    | PDistinct Column_
    | POrder Column_ String
    | PLimit Int

-- | Request DISTINCT results (honoured by 'sqlSelect' when the column is selected).
distinct :: Column c -> Pred
distinct c = PDistinct (column_ c)

-- | Limit the number of rows returned.
limit :: Int -> Pred
limit = PLimit

-- | Order results by a time column, newest first.
orderDesc :: Column UTCTime -> Pred
orderDesc c = POrder (column_ c) $ colTable c ++ "." ++ colName c ++ " DESC"

-- | Order results by a time column, oldest first.
orderAsc :: Column UTCTime -> Pred
orderAsc c = POrder (column_ c) $ colTable c ++ "." ++ colName c ++ " ASC"

-- | The nullable column is NULL.
nullP :: Column (Maybe c) -> Pred
nullP c = PNull (column_ c)

-- | The column matches an SQL LIKE pattern.
likeP :: ToField c => Column c -> c -> Pred
likeP (column_ -> c) (toField -> v) = PLike c v

-- | Both predicates hold.
(%&&) :: Pred -> Pred -> Pred
(%&&) a b = PAnd [a,b]

-- | Column equals value; comparing against a NULL value becomes @IS NULL@.
(%==) :: ToField c => Column c -> c -> Pred
(%==) (column_ -> c) (toField -> v)
    | v == SQLNull = PNull c
    | otherwise = PEq c v

-- | Column is greater than value; the value must not be NULL.
(%>) :: ToField c => Column c -> c -> Pred
(%>) (column_ -> c) (toField -> v)
    | v == SQLNull = error $ "Can't %> on a NULL"
    | otherwise = PGt c v

-- | Column is less than value; the value must not be NULL.
(%<) :: ToField c => Column c -> c -> Pred
(%<) (column_ -> c) (toField -> v)
    | v == SQLNull = error $ "Can't %< on a NULL"
    | otherwise = PLt c v

-- | Column differs from value; comparing against a NULL value becomes @IS NOT NULL@.
(%/=) :: ToField c => Column c -> c -> Pred
(%/=) (column_ -> c) (toField -> v)
    | v == SQLNull = PNotNull c
    | otherwise = PNEq c v

-- | Two columns are equal (a join condition). Both columns must be NOT NULL
-- (or rowid pseudo-columns), since @=@ never matches NULLs.
(%==%) :: ToField c => Column c -> Column c -> Pred
(%==%) c1 c2
    | isNull c1 || isNull c2 = error $ show ("Column must be NOT NULL to do %==%", show c1, show c2)
    | otherwise = PEqP (column_ c1) (column_ c2)
    where isNull c = not $ colSqlType c == "" || " NOT NULL" `isSuffixOf` colSqlType c

-- | Compile predicates into: the WHERE text (including any ORDER BY / LIMIT
-- suffix), the columns marked DISTINCT, every column mentioned (for the FROM
-- clause), and the positional parameter values.
unpred :: [Pred] -> (String, [Column_], [Column_], [SQLData])
unpred ps =
    ( whereStr ++ orderStr ++ limitStr
    , [x | PDistinct x <- dists]
    , whereCs ++ [x | POrder x _ <- orders]
    , whereVs )
    where
        isDistinct PDistinct{} = True; isDistinct _ = False
        isOrder POrder{} = True; isOrder _ = False
        isLimit PLimit{} = True; isLimit _ = False
        (dists, (orders, (limits, conds))) = second (second (partition isLimit) . partition isOrder) $ partition isDistinct ps

        (whereStr, whereCs, whereVs) = f $ PAnd conds
        -- multiple ordering terms must be comma-separated in SQL
        orderStr
            | null orders = ""
            | otherwise = " ORDER BY " ++ intercalate ", " [x | POrder _ x <- orders]
        -- only the first limit is used, matching the previous behaviour
        limitStr = maybe "" ((" LIMIT " ++) . show) $ listToMaybe [x | PLimit x <- limits]

        g Column{..} = colTable ++ "." ++ colName

        f :: Pred -> (String, [Column_], [SQLData])
        f (PNull c) = (g c ++ " IS NULL", [c], [])
        f (PNotNull c) = (g c ++ " IS NOT NULL", [c], [])
        f (PEq c v) = (g c ++ " == ?", [c], [v]) -- IS always works, but is a LOT slower
        f (PNEq c v) = (g c ++ " != ?", [c], [v]) -- IS always works, but is a LOT slower
        f (PGt c v) = (g c ++ " > ?", [c], [v]) -- IS always works, but is a LOT slower
        f (PLt c v) = (g c ++ " < ?", [c], [v]) -- IS always works, but is a LOT slower
        f (PEqP c1 c2) = (g c1 ++ " = " ++ g c2, [c1,c2], [])
        f (PLike c v) = (g c ++ " LIKE ?", [c], [v])
        f (PAnd []) = ("NULL IS NULL", [], [])
        f (PAnd [x]) = f x
        f (PAnd xs) = (intercalate " AND " ["(" ++ s ++ ")" | s <- ss], concat cs, concat vs)
            where (ss,cs,vs) = unzip3 $ map f xs
        -- reached only for distinct/order/limit nested inside %&&
        f _ = error "Unrecognised Pred"

-- | Lets tables declared with 'norowid' use @()@ as their rowid type.
instance FromField () where
    fromField _ = return ()