Recursive SQL queries, still need to test before use
This commit is contained in:
parent
55945e30f9
commit
76a627385c
4 changed files with 834 additions and 0 deletions
255
src/Database/Persist/Local/Class/PersistQueryRecursive.hs
Normal file
255
src/Database/Persist/Local/Class/PersistQueryRecursive.hs
Normal file
|
@ -0,0 +1,255 @@
|
|||
{- This file is part of Vervis.
|
||||
-
|
||||
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
|
||||
-
|
||||
- ♡ Copying is an act of love. Please copy, reuse and share.
|
||||
-
|
||||
- The author(s) have dedicated all copyright and related and neighboring
|
||||
- rights to this software to the public domain worldwide. This software is
|
||||
- distributed without any warranty.
|
||||
-
|
||||
- You should have received a copy of the CC0 Public Domain Dedication along
|
||||
- with this software. If not, see
|
||||
- <http://creativecommons.org/publicdomain/zero/1.0/>.
|
||||
-}
|
||||
|
||||
{- The code is based on PersistQuery. Actually, most of the difference is
|
||||
- slightly different names and 3 additional function parameters.
|
||||
-}
|
||||
|
||||
-- | Recursive queries are performed by taking the output of a recursion step,
|
||||
-- possibly modifying it, and using the result as the input for the next
|
||||
-- recursion step.
|
||||
--
|
||||
-- This module currently provides a single way to perform that recursive step:
|
||||
-- Match between the /id/ column and some other column which has the same type.
|
||||
--
|
||||
-- For example, suppose we have a `Message` type with a `messageParent` field.
|
||||
-- For given messages `a` and `b`, if `messageParent b == Just a` then `b` is a
|
||||
-- reply to `a`. Therefore, all the replies to a given message point to it
|
||||
-- using the `messageParent` field. And there can be replies to replies and so
|
||||
-- on, creating a tree of messages.
|
||||
--
|
||||
-- > Message
|
||||
-- > author PersonId
|
||||
-- > content Text
|
||||
-- > parent MessageId Maybe
|
||||
--
|
||||
-- If we start with a single message and follow the `messageParent` values
|
||||
-- recursively, we'll be able to get a list (or a tree) of the __ancestors__ of
|
||||
-- the message. Our message /a/ may be a reply to some other message /b/, and
|
||||
-- /b/ may be a reply to message /c/ and so on. Eventually, if there are no
|
||||
-- cycles and it's really a tree structure, we'll reach the root message, which
|
||||
-- has no parent.
|
||||
--
|
||||
-- But there's another way to recurse. What if we wanted to find the replies
|
||||
-- for a given message? And the replies of the replies, and so on? In other
|
||||
-- words, the __decendants__ of a given message. Suppose we start with a
|
||||
-- message /a/. We get a list of the replies of /a/, i.e. message whose parent
|
||||
-- is `Just a`. Then we find the replies of those messages, i.e. the replies of
|
||||
-- the replies of /a/. And so on, recursively, until we can't find more replies
|
||||
-- and then we stop.
|
||||
--
|
||||
-- Therefore we can perform the recursion in one of two directions:
|
||||
--
|
||||
-- - __Outwards__, i.e. follow from a message to its parents. More generally,
|
||||
-- given a persistent entity type `Foobar`, follow recursively using a
|
||||
-- specific field of it, whose type is `FoobarId`. It's called "outwards"
|
||||
-- because it's like following out-edges of a graph node, i.e. arrows
|
||||
-- pointing from a node towards other nodes.
|
||||
-- - __Inwards__, i.e. find the children (i.e. replies) of a message, and then
|
||||
-- their children, and so on. More generally, given a persistent entity type
|
||||
-- `Foobar`, find other values referring to it using a specific field, whose
|
||||
-- type is `FoobarId`, and recursive find such values for the results we get
|
||||
-- and so on. It's called "inwards" because it's like following in-edges of a
|
||||
-- graph node, i.e. arrows pointing from other nodes towards that node.
|
||||
--
|
||||
-- The 'RecursionDirection' type is used for specifying the direction.
|
||||
--
|
||||
-- When you follow all the children of an entity recursively, or all of its
|
||||
-- parents, we call the result you get the __transitive closure__ of the
|
||||
-- specific field you used. You can further specify the direction, i.e.
|
||||
-- __outward transitive closure__ or __inward transitive closure__. For
|
||||
-- examples, if you follow a message's parents recursively as in the example
|
||||
-- above, you get an outward transitive closure on the /parent/ field.
|
||||
--
|
||||
-- Note that the definition used here is __not__ the same as the mathematical
|
||||
-- definition. When you perform a recursive query without filters, you get not
|
||||
-- only the ancestors (or the decendants) of an entity, but also the root
|
||||
-- entity itself. In other words, even though a message is not a reply of
|
||||
-- itself, you'll still get it in the query result. If you want to get just the
|
||||
-- ancestors (or decendants), i.e. the actual transitive closer of the "is
|
||||
-- reply of" relation in the mathematical sense, use a filter to omit the root
|
||||
-- message based on the ID, i.e. @[MessageParent /= msgid]@.
|
||||
--
|
||||
-- Therefore, when the term "transitive closure" is used below, it means not
|
||||
-- just the ancestors (or decendants), but also the origin entity too.
|
||||
module Database.Persist.Local.Class.PersistQueryRecursive
|
||||
( RecursionDirection (..)
|
||||
, PersistQueryRecursive (..)
|
||||
, selectRecursivelySource
|
||||
, selectRecursivelyKeys
|
||||
, selectRecursivelyList
|
||||
, selectRecursivelyKeysList
|
||||
)
|
||||
where
|
||||
|
||||
import Prelude
|
||||
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Reader (MonadReader)
|
||||
import Control.Monad.Trans.Reader (ReaderT)
|
||||
import Control.Monad.Trans.Resource (MonadResource, release)
|
||||
import Data.Acquire (Acquire, allocateAcquire, with)
|
||||
import Database.Persist.Class
|
||||
import Database.Persist.Types
|
||||
|
||||
import qualified Data.Conduit as C
|
||||
import qualified Data.Conduit.List as CL
|
||||
|
||||
data RecursionDirection
|
||||
= RecOut
|
||||
| RecIn
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | Backends supporting recursive conditional operations.
|
||||
class PersistQuery backend => PersistQueryRecursive backend where
|
||||
-- | Update individual fields on any record in the transitive closure and
|
||||
-- matching the given criterion.
|
||||
updateRecursivelyWhere
|
||||
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [Update val]
|
||||
-> ReaderT backend m ()
|
||||
|
||||
-- | Delete all records in the transitive closure which match the given
|
||||
-- criterion.
|
||||
deleteRecursivelyWhere
|
||||
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> ReaderT backend m ()
|
||||
|
||||
-- | Get all records in the transitive closure, which match the given
|
||||
-- criterion, in the specified order. Returns also the identifiers.
|
||||
selectRecursivelySourceRes
|
||||
:: ( PersistEntity val
|
||||
, PersistEntityBackend val ~ backend
|
||||
, MonadIO m1
|
||||
, MonadIO m2
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> ReaderT backend m1 (Acquire (C.Source m2 (Entity val)))
|
||||
|
||||
-- | Get the 'Key's of all records in the transitive closure, which match
|
||||
-- the given criterion.
|
||||
selectRecursivelyKeysRes
|
||||
:: ( PersistEntity val
|
||||
, PersistEntityBackend val ~ backend
|
||||
, MonadIO m1
|
||||
, MonadIO m2
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> ReaderT backend m1 (Acquire (C.Source m2 (Key val)))
|
||||
|
||||
-- | The total number of records in the transitive closure which fulfill
|
||||
-- the given criterion.
|
||||
countRecursively
|
||||
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> ReaderT backend m Int
|
||||
|
||||
-- | Get all records in the transitive closure, which match the given
|
||||
-- criterion, in the specified order. Returns also the identifiers.
|
||||
selectRecursivelySource
|
||||
:: ( PersistQueryRecursive backend
|
||||
, MonadResource m
|
||||
, PersistEntity val
|
||||
, PersistEntityBackend val ~ backend
|
||||
, MonadReader env m
|
||||
, HasPersistBackend env backend
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> C.Source m (Entity val)
|
||||
selectRecursivelySource dir field root filts opts = do
|
||||
srcRes <-
|
||||
liftPersist $ selectRecursivelySourceRes dir field root filts opts
|
||||
(releaseKey, src) <- allocateAcquire srcRes
|
||||
src
|
||||
release releaseKey
|
||||
|
||||
-- | Get the 'Key's of all records in the transitive closure, which match the
|
||||
-- given criterion.
|
||||
selectRecursivelyKeys
|
||||
:: ( PersistQueryRecursive backend
|
||||
, MonadResource m
|
||||
, PersistEntity val
|
||||
, backend ~ PersistEntityBackend val
|
||||
, MonadReader env m
|
||||
, HasPersistBackend env backend
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> C.Source m (Key val)
|
||||
selectRecursivelyKeys dir field root filts opts = do
|
||||
srcRes <- liftPersist $ selectRecursivelyKeysRes dir field root filts opts
|
||||
(releaseKey, src) <- allocateAcquire srcRes
|
||||
src
|
||||
release releaseKey
|
||||
|
||||
-- | Call 'selectRecursivelySource' but return the result as a list.
|
||||
selectRecursivelyList
|
||||
:: ( PersistQueryRecursive backend
|
||||
, MonadIO m
|
||||
, PersistEntity val
|
||||
, PersistEntityBackend val ~ backend
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> ReaderT backend m [Entity val]
|
||||
selectRecursivelyList dir field root filts opts = do
|
||||
srcRes <- selectRecursivelySourceRes dir field root filts opts
|
||||
liftIO $ with srcRes (C.$$ CL.consume)
|
||||
|
||||
-- | Call 'selectRecursivelyKeys' but return the result as a list.
|
||||
selectRecursivelyKeysList
|
||||
:: ( PersistQueryRecursive backend
|
||||
, MonadIO m
|
||||
, PersistEntity val
|
||||
, PersistEntityBackend val ~ backend
|
||||
)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [SelectOpt val]
|
||||
-> ReaderT backend m [Key val]
|
||||
selectRecursivelyKeysList dir field root filts opts = do
|
||||
srcRes <- selectRecursivelyKeysRes dir field root filts opts
|
||||
liftIO $ with srcRes (C.$$ CL.consume)
|
239
src/Database/Persist/Local/Sql/Orphan/Common.hs
Normal file
239
src/Database/Persist/Local/Sql/Orphan/Common.hs
Normal file
|
@ -0,0 +1,239 @@
|
|||
{- This file contains (slightly modified) copies of unexported functions from
|
||||
- Database.Persist.Sql.Orphan.PersistQuery, which I need for my
|
||||
- PersistQueryRecursive implementation. They're released under MIT.
|
||||
-
|
||||
- This should be a temporary situation. Either my code moves to persistent and
|
||||
- the functions are reused there, or these functions become exported in
|
||||
- persistent and then I can import them instead of holding copies.
|
||||
-}
|
||||
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
||||
module Database.Persist.Local.Sql.Orphan.Common
|
||||
( fieldName
|
||||
, dummyFromFilts
|
||||
, getFiltsValues
|
||||
, updatePersistValue
|
||||
, filterClause
|
||||
, orderClause
|
||||
)
|
||||
where
|
||||
|
||||
import Prelude
|
||||
|
||||
import Data.List (inits, transpose)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Text (Text)
|
||||
import Database.Persist
|
||||
import Database.Persist.Sql
|
||||
import Database.Persist.Sql.Util
|
||||
|
||||
import qualified Data.Text as T
|
||||
|
||||
fieldName
|
||||
:: forall record typ.
|
||||
(PersistEntity record , PersistEntityBackend record ~ SqlBackend)
|
||||
=> EntityField record typ
|
||||
-> DBName
|
||||
fieldName f = fieldDB $ persistFieldDef f
|
||||
|
||||
dummyFromFilts :: [Filter v] -> Maybe v
|
||||
dummyFromFilts _ = Nothing
|
||||
|
||||
getFiltsValues
|
||||
:: forall val. (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||
=> SqlBackend
|
||||
-> [Filter val]
|
||||
-> [PersistValue]
|
||||
getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo
|
||||
|
||||
data OrNull = OrNullYes | OrNullNo
|
||||
|
||||
filterClauseHelper
|
||||
:: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||
=> Bool -- ^ include table name?
|
||||
-> Bool -- ^ include WHERE?
|
||||
-> SqlBackend
|
||||
-> OrNull
|
||||
-> [Filter val]
|
||||
-> (Text, [PersistValue])
|
||||
filterClauseHelper includeTable includeWhere conn orNull filters =
|
||||
( if not (T.null sql) && includeWhere
|
||||
then " WHERE " <> sql
|
||||
else sql
|
||||
, vals
|
||||
)
|
||||
where
|
||||
(sql, vals) = combineAND filters
|
||||
combineAND = combine " AND "
|
||||
|
||||
combine s fs =
|
||||
(T.intercalate s $ map wrapP a, mconcat b)
|
||||
where
|
||||
(a, b) = unzip $ map go fs
|
||||
wrapP x = T.concat ["(", x, ")"]
|
||||
|
||||
go (BackendFilter _) = error "BackendFilter not expected"
|
||||
go (FilterAnd []) = ("1=1", [])
|
||||
go (FilterAnd fs) = combineAND fs
|
||||
go (FilterOr []) = ("1=0", [])
|
||||
go (FilterOr fs) = combine " OR " fs
|
||||
go (Filter field value pfilter) =
|
||||
let t = entityDef $ dummyFromFilts [Filter field value pfilter]
|
||||
in case (isIdField field, entityPrimary t, allVals) of
|
||||
(True, Just pdef, PersistList ys:_) ->
|
||||
if length (compositeFields pdef) /= length ys
|
||||
then error $ "wrong number of entries in compositeFields vs PersistList allVals=" ++ show allVals
|
||||
else
|
||||
case (allVals, pfilter, isCompFilter pfilter) of
|
||||
([PersistList xs], Eq, _) ->
|
||||
let sqlcl=T.intercalate " and " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
|
||||
in (wrapSql sqlcl,xs)
|
||||
([PersistList xs], Ne, _) ->
|
||||
let sqlcl=T.intercalate " or " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
|
||||
in (wrapSql sqlcl,xs)
|
||||
(_, In, _) ->
|
||||
let xxs = transpose (map fromPersistList allVals)
|
||||
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
|
||||
in (wrapSql (T.intercalate " and " (map wrapSql sqls)), concat xxs)
|
||||
(_, NotIn, _) ->
|
||||
let xxs = transpose (map fromPersistList allVals)
|
||||
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
|
||||
in (wrapSql (T.intercalate " or " (map wrapSql sqls)), concat xxs)
|
||||
([PersistList xs], _, True) ->
|
||||
let zs = tail (inits (compositeFields pdef))
|
||||
sql1 = map (\b -> wrapSql (T.intercalate " and " (map (\(i,a) -> sql2 (i==length b) a) (zip [1..] b)))) zs
|
||||
sql2 islast a = connEscapeName conn (fieldDB a) <> (if islast then showSqlFilter pfilter else showSqlFilter Eq) <> "? "
|
||||
sqlcl = T.intercalate " or " sql1
|
||||
in (wrapSql sqlcl, concat (tail (inits xs)))
|
||||
(_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
|
||||
_ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals
|
||||
(True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef
|
||||
|
||||
_ -> case (isNull, pfilter, varCount) of
|
||||
(True, Eq, _) -> (name <> " IS NULL", [])
|
||||
(True, Ne, _) -> (name <> " IS NOT NULL", [])
|
||||
(False, Ne, _) -> (T.concat
|
||||
[ "("
|
||||
, name
|
||||
, " IS NULL OR "
|
||||
, name
|
||||
, " <> "
|
||||
, qmarks
|
||||
, ")"
|
||||
], notNullVals)
|
||||
-- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since
|
||||
-- not all databases support those words directly.
|
||||
(_, In, 0) -> ("1=2" <> orNullSuffix, [])
|
||||
(False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
|
||||
(True, In, _) -> (T.concat
|
||||
[ "("
|
||||
, name
|
||||
, " IS NULL OR "
|
||||
, name
|
||||
, " IN "
|
||||
, qmarks
|
||||
, ")"
|
||||
], notNullVals)
|
||||
(_, NotIn, 0) -> ("1=1", [])
|
||||
(False, NotIn, _) -> (T.concat
|
||||
[ "("
|
||||
, name
|
||||
, " IS NULL OR "
|
||||
, name
|
||||
, " NOT IN "
|
||||
, qmarks
|
||||
, ")"
|
||||
], notNullVals)
|
||||
(True, NotIn, _) -> (T.concat
|
||||
[ "("
|
||||
, name
|
||||
, " IS NOT NULL AND "
|
||||
, name
|
||||
, " NOT IN "
|
||||
, qmarks
|
||||
, ")"
|
||||
], notNullVals)
|
||||
_ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)
|
||||
|
||||
where
|
||||
isCompFilter Lt = True
|
||||
isCompFilter Le = True
|
||||
isCompFilter Gt = True
|
||||
isCompFilter Ge = True
|
||||
isCompFilter _ = False
|
||||
|
||||
wrapSql sqlcl = "(" <> sqlcl <> ")"
|
||||
fromPersistList (PersistList xs) = xs
|
||||
fromPersistList other = error $ "expected PersistList but found " ++ show other
|
||||
|
||||
filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
|
||||
filterValueToPersistValues v = map toPersistValue $ either return id v
|
||||
|
||||
orNullSuffix =
|
||||
case orNull of
|
||||
OrNullYes -> mconcat [" OR ", name, " IS NULL"]
|
||||
OrNullNo -> ""
|
||||
|
||||
isNull = any (== PersistNull) allVals
|
||||
notNullVals = filter (/= PersistNull) allVals
|
||||
allVals = filterValueToPersistValues value
|
||||
tn = connEscapeName conn $ entityDB
|
||||
$ entityDef $ dummyFromFilts [Filter field value pfilter]
|
||||
name =
|
||||
(if includeTable
|
||||
then ((tn <> ".") <>)
|
||||
else id)
|
||||
$ connEscapeName conn $ fieldName field
|
||||
qmarks = case value of
|
||||
Left _ -> "?"
|
||||
Right x ->
|
||||
let x' = filter (/= PersistNull) $ map toPersistValue x
|
||||
in "(" <> T.intercalate "," (map (const "?") x') <> ")"
|
||||
varCount = case value of
|
||||
Left _ -> 1
|
||||
Right x -> length x
|
||||
showSqlFilter Eq = "="
|
||||
showSqlFilter Ne = "<>"
|
||||
showSqlFilter Gt = ">"
|
||||
showSqlFilter Lt = "<"
|
||||
showSqlFilter Ge = ">="
|
||||
showSqlFilter Le = "<="
|
||||
showSqlFilter In = " IN "
|
||||
showSqlFilter NotIn = " NOT IN "
|
||||
showSqlFilter (BackendSpecificFilter s) = s
|
||||
|
||||
updatePersistValue :: Update v -> PersistValue
|
||||
updatePersistValue (Update _ v _) = toPersistValue v
|
||||
updatePersistValue _ = error "BackendUpdate not implemented"
|
||||
|
||||
filterClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||
=> Bool -- ^ include table name?
|
||||
-> SqlBackend
|
||||
-> [Filter val]
|
||||
-> Text
|
||||
filterClause b c = fst . filterClauseHelper b True c OrNullNo
|
||||
|
||||
orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||
=> Bool -- ^ include the table name
|
||||
-> SqlBackend
|
||||
-> SelectOpt val
|
||||
-> Text
|
||||
orderClause includeTable conn o =
|
||||
case o of
|
||||
Asc x -> name x
|
||||
Desc x -> name x <> " DESC"
|
||||
_ -> error "orderClause: expected Asc or Desc, not limit or offset"
|
||||
where
|
||||
dummyFromOrder :: SelectOpt a -> Maybe a
|
||||
dummyFromOrder _ = Nothing
|
||||
|
||||
tn = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o
|
||||
|
||||
name :: (PersistEntityBackend record ~ SqlBackend, PersistEntity record)
|
||||
=> EntityField record typ -> Text
|
||||
name x =
|
||||
(if includeTable
|
||||
then ((tn <> ".") <>)
|
||||
else id)
|
||||
$ connEscapeName conn $ fieldName x
|
333
src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs
Normal file
333
src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs
Normal file
|
@ -0,0 +1,333 @@
|
|||
{- This file is part of Vervis.
|
||||
-
|
||||
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
|
||||
-
|
||||
- ♡ Copying is an act of love. Please copy, reuse and share.
|
||||
-
|
||||
- The author(s) have dedicated all copyright and related and neighboring
|
||||
- rights to this software to the public domain worldwide. This software is
|
||||
- distributed without any warranty.
|
||||
-
|
||||
- You should have received a copy of the CC0 Public Domain Dedication along
|
||||
- with this software. If not, see
|
||||
- <http://creativecommons.org/publicdomain/zero/1.0/>.
|
||||
-}
|
||||
|
||||
module Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
|
||||
( deleteRecursivelyWhereCount
|
||||
, updateRecursivelyWhereCount
|
||||
)
|
||||
where
|
||||
|
||||
import Prelude
|
||||
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Reader (ReaderT, ask)
|
||||
import Control.Exception (throwIO)
|
||||
import Data.ByteString.Char8 (readInteger)
|
||||
import Data.Conduit (($=))
|
||||
import Data.Foldable (find)
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (isJust)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Text (Text)
|
||||
import Database.Persist
|
||||
import Database.Persist.Sql
|
||||
import Database.Persist.Sql.Util
|
||||
|
||||
import qualified Data.Conduit.List as CL (head, mapM)
|
||||
import qualified Data.Text as T (pack, unpack, intercalate)
|
||||
|
||||
import Database.Persist.Local.Class.PersistQueryRecursive
|
||||
import Database.Persist.Local.Sql.Orphan.Common
|
||||
|
||||
instance PersistQueryRecursive SqlBackend where
|
||||
updateRecursivelyWhere dir field root filts upds =
|
||||
void $ updateRecursivelyWhereCount dir field root filts upds
|
||||
|
||||
deleteRecursivelyWhere dir field root filts =
|
||||
void $ deleteRecursivelyWhereCount dir field root filts
|
||||
|
||||
selectRecursivelySourceRes dir field root filts opts = do
|
||||
conn <- ask
|
||||
let (sql, vals, parse) = sqlValsParse conn
|
||||
srcRes <- rawQueryRes sql vals
|
||||
return $ fmap ($= CL.mapM parse) srcRes
|
||||
where
|
||||
sqlValsParse conn = (sql, vals, parse)
|
||||
where
|
||||
(temp, isRoot, cols, qcols, sqlWith) =
|
||||
withRecursive dir field root conn t (flip entityColumnNames)
|
||||
|
||||
(limit, offset, orders) = limitOffsetOrder opts
|
||||
|
||||
parse xs =
|
||||
case parseEntityValues t xs of
|
||||
Left s -> liftIO $ throwIO $ PersistMarshalError s
|
||||
Right row -> return row
|
||||
t = entityDef $ dummyFromFilts filts
|
||||
wher =
|
||||
if null filts
|
||||
then ""
|
||||
else filterClause False conn filts
|
||||
ord =
|
||||
case map (orderClause False conn) orders of
|
||||
[] -> ""
|
||||
ords -> " ORDER BY " <> T.intercalate "," ords
|
||||
sql =
|
||||
mappend sqlWith $
|
||||
connLimitOffset conn (limit, offset) (not $ null orders) $
|
||||
mconcat
|
||||
[ "SELECT "
|
||||
, cols
|
||||
, " FROM "
|
||||
, connEscapeName conn temp
|
||||
, wher
|
||||
, ord
|
||||
]
|
||||
vals = getFiltsValues conn $ isRoot : filts
|
||||
|
||||
selectRecursivelyKeysRes dir field root filts opts = do
|
||||
conn <- ask
|
||||
let (sql, vals, parse) = sqlValsParse conn
|
||||
srcRes <- rawQueryRes sql vals
|
||||
return $ fmap ($= CL.mapM parse) srcRes
|
||||
where
|
||||
sqlValsParse conn = (sql, vals, parse)
|
||||
where
|
||||
(temp, isRoot, cols, qcols, sqlWith) =
|
||||
withRecursive dir field root conn t dbIdColumns
|
||||
|
||||
(limit, offset, orders) = limitOffsetOrder opts
|
||||
|
||||
parse xs = do
|
||||
keyvals <-
|
||||
case entityPrimary t of
|
||||
Nothing ->
|
||||
case xs of
|
||||
[PersistInt64 x] ->
|
||||
return [PersistInt64 x]
|
||||
[PersistDouble x] ->
|
||||
-- oracle returns Double
|
||||
return [PersistInt64 $ truncate x]
|
||||
_ ->
|
||||
liftIO $ throwIO $ PersistMarshalError $
|
||||
"Unexpected in selectKeys False: " <>
|
||||
T.pack (show xs)
|
||||
Just pdef ->
|
||||
let pks = map fieldHaskell $ compositeFields pdef
|
||||
keyvals =
|
||||
map snd $
|
||||
filter
|
||||
(\ (a, _) ->
|
||||
let ret = isJust (find (== a) pks)
|
||||
in ret
|
||||
) $
|
||||
zip (map fieldHaskell $ entityFields t) xs
|
||||
in return keyvals
|
||||
case keyFromValues keyvals of
|
||||
Right k -> return k
|
||||
Left _ -> error "selectKeysImpl: keyFromValues failed"
|
||||
t = entityDef $ dummyFromFilts filts
|
||||
wher =
|
||||
if null filts
|
||||
then ""
|
||||
else filterClause False conn filts
|
||||
ord =
|
||||
case map (orderClause False conn) orders of
|
||||
[] -> ""
|
||||
ords -> " ORDER BY " <> T.intercalate "," ords
|
||||
sql =
|
||||
mappend sqlWith $
|
||||
connLimitOffset conn (limit, offset) (not $ null orders) $
|
||||
mconcat
|
||||
[ "SELECT "
|
||||
, cols
|
||||
, " FROM "
|
||||
, connEscapeName conn temp
|
||||
, wher
|
||||
, ord
|
||||
]
|
||||
vals = getFiltsValues conn $ isRoot : filts
|
||||
|
||||
countRecursively dir field root filts = do
|
||||
conn <- ask
|
||||
let (sql, vals) = sqlAndVals conn
|
||||
withRawQuery sql vals $ do
|
||||
mm <- CL.head
|
||||
case mm of
|
||||
Just [PersistInt64 i] -> return $ fromIntegral i
|
||||
Just [PersistDouble i] ->return $ fromIntegral (truncate i :: Int64) -- gb oracle
|
||||
Just [PersistByteString i] -> case readInteger i of -- gb mssql
|
||||
Just (ret,"") -> return $ fromIntegral ret
|
||||
xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]"
|
||||
Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]"
|
||||
Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]"
|
||||
where
|
||||
sqlAndVals conn = (sql, vals)
|
||||
where
|
||||
(temp, isRoot, cols, qcols, sqlWith) =
|
||||
withRecursive dir field root conn t dbIdColumns
|
||||
|
||||
t = entityDef $ dummyFromFilts filts
|
||||
wher =
|
||||
if null filts
|
||||
then ""
|
||||
else filterClause False conn filts
|
||||
sql = mconcat
|
||||
[ sqlWith
|
||||
, "SELECT COUNT(*) FROM "
|
||||
, connEscapeName conn temp
|
||||
, wher
|
||||
]
|
||||
vals = getFiltsValues conn $ isRoot : filts
|
||||
|
||||
-- | Same as 'deleteRecursivelyWhere', but returns the number of rows affected.
|
||||
deleteRecursivelyWhereCount
|
||||
:: (PersistEntity val, MonadIO m, PersistEntityBackend val ~ SqlBackend)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> ReaderT SqlBackend m Int64
|
||||
deleteRecursivelyWhereCount dir field root filts = do
|
||||
conn <- ask
|
||||
let (sql, vals) = sqlAndVals conn
|
||||
rawExecuteCount sql vals
|
||||
where
|
||||
sqlAndVals conn = (sql, vals)
|
||||
where
|
||||
(temp, isRoot, cols, qcols, sqlWith) =
|
||||
withRecursive dir field root conn t dbIdColumns
|
||||
|
||||
t = entityDef $ dummyFromFilts filts
|
||||
wher = mconcat
|
||||
[ if null filts
|
||||
then " WHERE ( "
|
||||
else filterClause False conn filts <> " AND ( "
|
||||
, connEscapeName conn $ fieldDB $ entityId t
|
||||
, " IN (SELECT "
|
||||
, connEscapeName conn $ fieldDB $ entityId t
|
||||
, " FROM "
|
||||
, connEscapeName conn temp
|
||||
, ") ) "
|
||||
]
|
||||
sql = mconcat
|
||||
[ sqlWith
|
||||
, "DELETE FROM "
|
||||
, connEscapeName conn $ entityDB t
|
||||
, wher
|
||||
]
|
||||
vals = getFiltsValues conn $ isRoot : filts
|
||||
|
||||
-- | Same as 'updateRecursivelyWhere', but returns the number of rows affected.
|
||||
updateRecursivelyWhereCount
|
||||
:: (PersistEntity val, MonadIO m, SqlBackend ~ PersistEntityBackend val)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> [Filter val]
|
||||
-> [Update val]
|
||||
-> ReaderT SqlBackend m Int64
|
||||
updateRecursivelyWhereCount _ _ _ _ [] = return 0
|
||||
updateRecursivelyWhereCount dir field root filts upds = do
|
||||
conn <- ask
|
||||
let (sql, vals) = sqlAndVals conn
|
||||
rawExecuteCount sql vals
|
||||
where
|
||||
sqlAndVals conn = (sql, vals)
|
||||
where
|
||||
(temp, isRoot, cols, qcols, sqlWith) =
|
||||
withRecursive dir field root conn t dbIdColumns
|
||||
|
||||
t = entityDef $ dummyFromFilts filts
|
||||
|
||||
go'' n Assign = n <> "=?"
|
||||
go'' n Add = mconcat [n, "=", n, "+?"]
|
||||
go'' n Subtract = mconcat [n, "=", n, "-?"]
|
||||
go'' n Multiply = mconcat [n, "=", n, "*?"]
|
||||
go'' n Divide = mconcat [n, "=", n, "/?"]
|
||||
go'' _ (BackendSpecificUpdate up) =
|
||||
error $ T.unpack $ "BackendSpecificUpdate " <> up <> " not supported"
|
||||
go' (x, pu) = go'' (connEscapeName conn x) pu
|
||||
go x = (updateField x, updateUpdate x)
|
||||
|
||||
updateField (Update f _ _) = fieldName f
|
||||
updateField _ = error "BackendUpdate not implemented"
|
||||
|
||||
wher = mconcat
|
||||
[ if null filts
|
||||
then " WHERE ( "
|
||||
else filterClause False conn filts <> " AND ( "
|
||||
, connEscapeName conn $ fieldDB $ entityId t
|
||||
, " IN (SELECT "
|
||||
, connEscapeName conn $ fieldDB $ entityId t
|
||||
, " FROM "
|
||||
, connEscapeName conn temp
|
||||
, ") ) "
|
||||
]
|
||||
sql = mconcat
|
||||
[ sqlWith
|
||||
, "UPDATE "
|
||||
, connEscapeName conn $ entityDB t
|
||||
, " SET "
|
||||
, T.intercalate "," $ map (go' . go) upds
|
||||
, wher
|
||||
]
|
||||
vals =
|
||||
getFiltsValues conn [isRoot] ++
|
||||
map updatePersistValue upds ++
|
||||
getFiltsValues conn filts
|
||||
|
||||
withRecursive
|
||||
:: (PersistEntity val, SqlBackend ~ PersistEntityBackend val)
|
||||
=> RecursionDirection
|
||||
-> EntityField val (Maybe (Key val))
|
||||
-> Key val
|
||||
-> SqlBackend
|
||||
-> EntityDef
|
||||
-> (SqlBackend -> EntityDef -> [Text])
|
||||
-> (DBName, Filter val, Text, DBName -> Text, Text)
|
||||
withRecursive dir field root conn t getcols =
|
||||
let temp = DBName "temp_hierarchy_cte"
|
||||
isRoot = persistIdField ==. root
|
||||
cols = T.intercalate "," $ getcols conn t
|
||||
qcols name =
|
||||
T.intercalate ", " $
|
||||
map ((connEscapeName conn name <>) . ("." <>)) $
|
||||
getcols conn t
|
||||
sql = mconcat
|
||||
[ "WITH RECURSIVE "
|
||||
, connEscapeName conn temp
|
||||
, "("
|
||||
, cols
|
||||
, ") AS ( SELECT "
|
||||
, cols
|
||||
, " FROM "
|
||||
, connEscapeName conn $ entityDB t
|
||||
, filterClause False conn [isRoot]
|
||||
--, " WHERE "
|
||||
--, connEscapeName conn $ fieldDB $ entityId t
|
||||
--, " = ?"
|
||||
, " UNION SELECT "
|
||||
, qcols temp
|
||||
, " FROM "
|
||||
, connEscapeName conn $ entityDB t
|
||||
, ", "
|
||||
, connEscapeName conn temp
|
||||
, " WHERE "
|
||||
, connEscapeName conn $ entityDB t
|
||||
, "."
|
||||
, connEscapeName conn $ fieldDB $ case dir of
|
||||
RecOut -> persistFieldDef field
|
||||
RecIn -> entityId t
|
||||
, " = "
|
||||
, connEscapeName conn temp
|
||||
, "."
|
||||
, connEscapeName conn $ fieldDB $ case dir of
|
||||
RecOut -> entityId t
|
||||
RecIn -> persistFieldDef field
|
||||
, " ) "
|
||||
]
|
||||
in (temp, isRoot, cols, qcols, sql)
|
|
@ -64,6 +64,9 @@ library
|
|||
Database.Esqueleto.Local
|
||||
Database.Persist.Class.Local
|
||||
Database.Persist.Sql.Local
|
||||
Database.Persist.Local.Class.PersistQueryRecursive
|
||||
Database.Persist.Local.Sql.Orphan.Common
|
||||
Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
|
||||
Development.DarcsRev
|
||||
Formatting.CaseInsensitive
|
||||
Network.SSH.Local
|
||||
|
@ -219,6 +222,8 @@ library
|
|||
, memory
|
||||
, monad-control
|
||||
, monad-logger
|
||||
-- for Database.Persist.Local
|
||||
, mtl
|
||||
, pandoc
|
||||
, pandoc-types
|
||||
-- for PathPiece instance for CI, Web.PathPieces.Local
|
||||
|
@ -227,6 +232,8 @@ library
|
|||
, persistent-postgresql
|
||||
, persistent-template
|
||||
, process
|
||||
-- for Database.Persist.Local
|
||||
, resourcet
|
||||
, safe
|
||||
, shakespeare
|
||||
, ssh
|
||||
|
|
Loading…
Reference in a new issue