Recursive SQL queries, still need to test before use
This commit is contained in:
parent
55945e30f9
commit
76a627385c
4 changed files with 834 additions and 0 deletions
255
src/Database/Persist/Local/Class/PersistQueryRecursive.hs
Normal file
255
src/Database/Persist/Local/Class/PersistQueryRecursive.hs
Normal file
|
@ -0,0 +1,255 @@
|
||||||
|
{- This file is part of Vervis.
|
||||||
|
-
|
||||||
|
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
|
||||||
|
-
|
||||||
|
- ♡ Copying is an act of love. Please copy, reuse and share.
|
||||||
|
-
|
||||||
|
- The author(s) have dedicated all copyright and related and neighboring
|
||||||
|
- rights to this software to the public domain worldwide. This software is
|
||||||
|
- distributed without any warranty.
|
||||||
|
-
|
||||||
|
- You should have received a copy of the CC0 Public Domain Dedication along
|
||||||
|
- with this software. If not, see
|
||||||
|
- <http://creativecommons.org/publicdomain/zero/1.0/>.
|
||||||
|
-}
|
||||||
|
|
||||||
|
{- The code is based on PersistQuery. Actually, most of the difference is
|
||||||
|
- slightly different names and 3 additional function parameters.
|
||||||
|
-}
|
||||||
|
|
||||||
|
-- | Recursive queries are performed by taking the output of a recursion step,
|
||||||
|
-- possibly modifying it, and using the result as the input for the next
|
||||||
|
-- recursion step.
|
||||||
|
--
|
||||||
|
-- This module currently provides a single way to perform that recursive step:
|
||||||
|
-- Match between the /id/ column and some other column which has the same type.
|
||||||
|
--
|
||||||
|
-- For example, suppose we have a `Message` type with a `messageParent` field.
|
||||||
|
-- For given messages `a` and `b`, if `messageParent b == Just a` then `b` is a
|
||||||
|
-- reply to `a`. Therefore, all the replies to a given message point to it
|
||||||
|
-- using the `messageParent` field. And there can be replies to replies and so
|
||||||
|
-- on, creating a tree of messages.
|
||||||
|
--
|
||||||
|
-- > Message
|
||||||
|
-- > author PersonId
|
||||||
|
-- > content Text
|
||||||
|
-- > parent MessageId Maybe
|
||||||
|
--
|
||||||
|
-- If we start with a single message and follow the `messageParent` values
|
||||||
|
-- recursively, we'll be able to get a list (or a tree) of the __ancestors__ of
|
||||||
|
-- the message. Our message /a/ may be a reply to some other message /b/, and
|
||||||
|
-- /b/ may be a reply to message /c/ and so on. Eventually, if there are no
|
||||||
|
-- cycles and it's really a tree structure, we'll reach the root message, which
|
||||||
|
-- has no parent.
|
||||||
|
--
|
||||||
|
-- But there's another way to recurse. What if we wanted to find the replies
|
||||||
|
-- for a given message? And the replies of the replies, and so on? In other
|
||||||
|
-- words, the __decendants__ of a given message. Suppose we start with a
|
||||||
|
-- message /a/. We get a list of the replies of /a/, i.e. message whose parent
|
||||||
|
-- is `Just a`. Then we find the replies of those messages, i.e. the replies of
|
||||||
|
-- the replies of /a/. And so on, recursively, until we can't find more replies
|
||||||
|
-- and then we stop.
|
||||||
|
--
|
||||||
|
-- Therefore we can perform the recursion in one of two directions:
|
||||||
|
--
|
||||||
|
-- - __Outwards__, i.e. follow from a message to its parents. More generally,
|
||||||
|
-- given a persistent entity type `Foobar`, follow recursively using a
|
||||||
|
-- specific field of it, whose type is `FoobarId`. It's called "outwards"
|
||||||
|
-- because it's like following out-edges of a graph node, i.e. arrows
|
||||||
|
-- pointing from a node towards other nodes.
|
||||||
|
-- - __Inwards__, i.e. find the children (i.e. replies) of a message, and then
|
||||||
|
-- their children, and so on. More generally, given a persistent entity type
|
||||||
|
-- `Foobar`, find other values referring to it using a specific field, whose
|
||||||
|
-- type is `FoobarId`, and recursive find such values for the results we get
|
||||||
|
-- and so on. It's called "inwards" because it's like following in-edges of a
|
||||||
|
-- graph node, i.e. arrows pointing from other nodes towards that node.
|
||||||
|
--
|
||||||
|
-- The 'RecursionDirection' type is used for specifying the direction.
|
||||||
|
--
|
||||||
|
-- When you follow all the children of an entity recursively, or all of its
|
||||||
|
-- parents, we call the result you get the __transitive closure__ of the
|
||||||
|
-- specific field you used. You can further specify the direction, i.e.
|
||||||
|
-- __outward transitive closure__ or __inward transitive closure__. For
|
||||||
|
-- examples, if you follow a message's parents recursively as in the example
|
||||||
|
-- above, you get an outward transitive closure on the /parent/ field.
|
||||||
|
--
|
||||||
|
-- Note that the definition used here is __not__ the same as the mathematical
|
||||||
|
-- definition. When you perform a recursive query without filters, you get not
|
||||||
|
-- only the ancestors (or the decendants) of an entity, but also the root
|
||||||
|
-- entity itself. In other words, even though a message is not a reply of
|
||||||
|
-- itself, you'll still get it in the query result. If you want to get just the
|
||||||
|
-- ancestors (or decendants), i.e. the actual transitive closer of the "is
|
||||||
|
-- reply of" relation in the mathematical sense, use a filter to omit the root
|
||||||
|
-- message based on the ID, i.e. @[MessageParent /= msgid]@.
|
||||||
|
--
|
||||||
|
-- Therefore, when the term "transitive closure" is used below, it means not
|
||||||
|
-- just the ancestors (or decendants), but also the origin entity too.
|
||||||
|
module Database.Persist.Local.Class.PersistQueryRecursive
|
||||||
|
( RecursionDirection (..)
|
||||||
|
, PersistQueryRecursive (..)
|
||||||
|
, selectRecursivelySource
|
||||||
|
, selectRecursivelyKeys
|
||||||
|
, selectRecursivelyList
|
||||||
|
, selectRecursivelyKeysList
|
||||||
|
)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Prelude
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class
|
||||||
|
import Control.Monad.Reader (MonadReader)
|
||||||
|
import Control.Monad.Trans.Reader (ReaderT)
|
||||||
|
import Control.Monad.Trans.Resource (MonadResource, release)
|
||||||
|
import Data.Acquire (Acquire, allocateAcquire, with)
|
||||||
|
import Database.Persist.Class
|
||||||
|
import Database.Persist.Types
|
||||||
|
|
||||||
|
import qualified Data.Conduit as C
|
||||||
|
import qualified Data.Conduit.List as CL
|
||||||
|
|
||||||
|
data RecursionDirection
|
||||||
|
= RecOut
|
||||||
|
| RecIn
|
||||||
|
deriving (Eq, Show)
|
||||||
|
|
||||||
|
-- | Backends supporting recursive conditional operations.
|
||||||
|
class PersistQuery backend => PersistQueryRecursive backend where
|
||||||
|
-- | Update individual fields on any record in the transitive closure and
|
||||||
|
-- matching the given criterion.
|
||||||
|
updateRecursivelyWhere
|
||||||
|
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [Update val]
|
||||||
|
-> ReaderT backend m ()
|
||||||
|
|
||||||
|
-- | Delete all records in the transitive closure which match the given
|
||||||
|
-- criterion.
|
||||||
|
deleteRecursivelyWhere
|
||||||
|
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> ReaderT backend m ()
|
||||||
|
|
||||||
|
-- | Get all records in the transitive closure, which match the given
|
||||||
|
-- criterion, in the specified order. Returns also the identifiers.
|
||||||
|
selectRecursivelySourceRes
|
||||||
|
:: ( PersistEntity val
|
||||||
|
, PersistEntityBackend val ~ backend
|
||||||
|
, MonadIO m1
|
||||||
|
, MonadIO m2
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> ReaderT backend m1 (Acquire (C.Source m2 (Entity val)))
|
||||||
|
|
||||||
|
-- | Get the 'Key's of all records in the transitive closure, which match
|
||||||
|
-- the given criterion.
|
||||||
|
selectRecursivelyKeysRes
|
||||||
|
:: ( PersistEntity val
|
||||||
|
, PersistEntityBackend val ~ backend
|
||||||
|
, MonadIO m1
|
||||||
|
, MonadIO m2
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> ReaderT backend m1 (Acquire (C.Source m2 (Key val)))
|
||||||
|
|
||||||
|
-- | The total number of records in the transitive closure which fulfill
|
||||||
|
-- the given criterion.
|
||||||
|
countRecursively
|
||||||
|
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> ReaderT backend m Int
|
||||||
|
|
||||||
|
-- | Get all records in the transitive closure, which match the given
|
||||||
|
-- criterion, in the specified order. Returns also the identifiers.
|
||||||
|
selectRecursivelySource
|
||||||
|
:: ( PersistQueryRecursive backend
|
||||||
|
, MonadResource m
|
||||||
|
, PersistEntity val
|
||||||
|
, PersistEntityBackend val ~ backend
|
||||||
|
, MonadReader env m
|
||||||
|
, HasPersistBackend env backend
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> C.Source m (Entity val)
|
||||||
|
selectRecursivelySource dir field root filts opts = do
|
||||||
|
srcRes <-
|
||||||
|
liftPersist $ selectRecursivelySourceRes dir field root filts opts
|
||||||
|
(releaseKey, src) <- allocateAcquire srcRes
|
||||||
|
src
|
||||||
|
release releaseKey
|
||||||
|
|
||||||
|
-- | Get the 'Key's of all records in the transitive closure, which match the
|
||||||
|
-- given criterion.
|
||||||
|
selectRecursivelyKeys
|
||||||
|
:: ( PersistQueryRecursive backend
|
||||||
|
, MonadResource m
|
||||||
|
, PersistEntity val
|
||||||
|
, backend ~ PersistEntityBackend val
|
||||||
|
, MonadReader env m
|
||||||
|
, HasPersistBackend env backend
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> C.Source m (Key val)
|
||||||
|
selectRecursivelyKeys dir field root filts opts = do
|
||||||
|
srcRes <- liftPersist $ selectRecursivelyKeysRes dir field root filts opts
|
||||||
|
(releaseKey, src) <- allocateAcquire srcRes
|
||||||
|
src
|
||||||
|
release releaseKey
|
||||||
|
|
||||||
|
-- | Call 'selectRecursivelySource' but return the result as a list.
|
||||||
|
selectRecursivelyList
|
||||||
|
:: ( PersistQueryRecursive backend
|
||||||
|
, MonadIO m
|
||||||
|
, PersistEntity val
|
||||||
|
, PersistEntityBackend val ~ backend
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> ReaderT backend m [Entity val]
|
||||||
|
selectRecursivelyList dir field root filts opts = do
|
||||||
|
srcRes <- selectRecursivelySourceRes dir field root filts opts
|
||||||
|
liftIO $ with srcRes (C.$$ CL.consume)
|
||||||
|
|
||||||
|
-- | Call 'selectRecursivelyKeys' but return the result as a list.
|
||||||
|
selectRecursivelyKeysList
|
||||||
|
:: ( PersistQueryRecursive backend
|
||||||
|
, MonadIO m
|
||||||
|
, PersistEntity val
|
||||||
|
, PersistEntityBackend val ~ backend
|
||||||
|
)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [SelectOpt val]
|
||||||
|
-> ReaderT backend m [Key val]
|
||||||
|
selectRecursivelyKeysList dir field root filts opts = do
|
||||||
|
srcRes <- selectRecursivelyKeysRes dir field root filts opts
|
||||||
|
liftIO $ with srcRes (C.$$ CL.consume)
|
239
src/Database/Persist/Local/Sql/Orphan/Common.hs
Normal file
239
src/Database/Persist/Local/Sql/Orphan/Common.hs
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
{- This file contains (slightly modified) copies of unexported functions from
|
||||||
|
- Database.Persist.Sql.Orphan.PersistQuery, which I need for my
|
||||||
|
- PersistQueryRecursive implementation. They're released under MIT.
|
||||||
|
-
|
||||||
|
- This should be a temporary situation. Either my code moves to persistent and
|
||||||
|
- the functions are reused there, or these functions become exported in
|
||||||
|
- persistent and then I can import them instead of holding copies.
|
||||||
|
-}
|
||||||
|
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
|
||||||
|
module Database.Persist.Local.Sql.Orphan.Common
|
||||||
|
( fieldName
|
||||||
|
, dummyFromFilts
|
||||||
|
, getFiltsValues
|
||||||
|
, updatePersistValue
|
||||||
|
, filterClause
|
||||||
|
, orderClause
|
||||||
|
)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Prelude
|
||||||
|
|
||||||
|
import Data.List (inits, transpose)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Database.Persist
|
||||||
|
import Database.Persist.Sql
|
||||||
|
import Database.Persist.Sql.Util
|
||||||
|
|
||||||
|
import qualified Data.Text as T
|
||||||
|
|
||||||
|
fieldName
|
||||||
|
:: forall record typ.
|
||||||
|
(PersistEntity record , PersistEntityBackend record ~ SqlBackend)
|
||||||
|
=> EntityField record typ
|
||||||
|
-> DBName
|
||||||
|
fieldName f = fieldDB $ persistFieldDef f
|
||||||
|
|
||||||
|
dummyFromFilts :: [Filter v] -> Maybe v
|
||||||
|
dummyFromFilts _ = Nothing
|
||||||
|
|
||||||
|
getFiltsValues
|
||||||
|
:: forall val. (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||||
|
=> SqlBackend
|
||||||
|
-> [Filter val]
|
||||||
|
-> [PersistValue]
|
||||||
|
getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo
|
||||||
|
|
||||||
|
data OrNull = OrNullYes | OrNullNo
|
||||||
|
|
||||||
|
filterClauseHelper
|
||||||
|
:: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||||
|
=> Bool -- ^ include table name?
|
||||||
|
-> Bool -- ^ include WHERE?
|
||||||
|
-> SqlBackend
|
||||||
|
-> OrNull
|
||||||
|
-> [Filter val]
|
||||||
|
-> (Text, [PersistValue])
|
||||||
|
filterClauseHelper includeTable includeWhere conn orNull filters =
|
||||||
|
( if not (T.null sql) && includeWhere
|
||||||
|
then " WHERE " <> sql
|
||||||
|
else sql
|
||||||
|
, vals
|
||||||
|
)
|
||||||
|
where
|
||||||
|
(sql, vals) = combineAND filters
|
||||||
|
combineAND = combine " AND "
|
||||||
|
|
||||||
|
combine s fs =
|
||||||
|
(T.intercalate s $ map wrapP a, mconcat b)
|
||||||
|
where
|
||||||
|
(a, b) = unzip $ map go fs
|
||||||
|
wrapP x = T.concat ["(", x, ")"]
|
||||||
|
|
||||||
|
go (BackendFilter _) = error "BackendFilter not expected"
|
||||||
|
go (FilterAnd []) = ("1=1", [])
|
||||||
|
go (FilterAnd fs) = combineAND fs
|
||||||
|
go (FilterOr []) = ("1=0", [])
|
||||||
|
go (FilterOr fs) = combine " OR " fs
|
||||||
|
go (Filter field value pfilter) =
|
||||||
|
let t = entityDef $ dummyFromFilts [Filter field value pfilter]
|
||||||
|
in case (isIdField field, entityPrimary t, allVals) of
|
||||||
|
(True, Just pdef, PersistList ys:_) ->
|
||||||
|
if length (compositeFields pdef) /= length ys
|
||||||
|
then error $ "wrong number of entries in compositeFields vs PersistList allVals=" ++ show allVals
|
||||||
|
else
|
||||||
|
case (allVals, pfilter, isCompFilter pfilter) of
|
||||||
|
([PersistList xs], Eq, _) ->
|
||||||
|
let sqlcl=T.intercalate " and " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
|
||||||
|
in (wrapSql sqlcl,xs)
|
||||||
|
([PersistList xs], Ne, _) ->
|
||||||
|
let sqlcl=T.intercalate " or " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
|
||||||
|
in (wrapSql sqlcl,xs)
|
||||||
|
(_, In, _) ->
|
||||||
|
let xxs = transpose (map fromPersistList allVals)
|
||||||
|
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
|
||||||
|
in (wrapSql (T.intercalate " and " (map wrapSql sqls)), concat xxs)
|
||||||
|
(_, NotIn, _) ->
|
||||||
|
let xxs = transpose (map fromPersistList allVals)
|
||||||
|
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
|
||||||
|
in (wrapSql (T.intercalate " or " (map wrapSql sqls)), concat xxs)
|
||||||
|
([PersistList xs], _, True) ->
|
||||||
|
let zs = tail (inits (compositeFields pdef))
|
||||||
|
sql1 = map (\b -> wrapSql (T.intercalate " and " (map (\(i,a) -> sql2 (i==length b) a) (zip [1..] b)))) zs
|
||||||
|
sql2 islast a = connEscapeName conn (fieldDB a) <> (if islast then showSqlFilter pfilter else showSqlFilter Eq) <> "? "
|
||||||
|
sqlcl = T.intercalate " or " sql1
|
||||||
|
in (wrapSql sqlcl, concat (tail (inits xs)))
|
||||||
|
(_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
|
||||||
|
_ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals
|
||||||
|
(True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef
|
||||||
|
|
||||||
|
_ -> case (isNull, pfilter, varCount) of
|
||||||
|
(True, Eq, _) -> (name <> " IS NULL", [])
|
||||||
|
(True, Ne, _) -> (name <> " IS NOT NULL", [])
|
||||||
|
(False, Ne, _) -> (T.concat
|
||||||
|
[ "("
|
||||||
|
, name
|
||||||
|
, " IS NULL OR "
|
||||||
|
, name
|
||||||
|
, " <> "
|
||||||
|
, qmarks
|
||||||
|
, ")"
|
||||||
|
], notNullVals)
|
||||||
|
-- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since
|
||||||
|
-- not all databases support those words directly.
|
||||||
|
(_, In, 0) -> ("1=2" <> orNullSuffix, [])
|
||||||
|
(False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
|
||||||
|
(True, In, _) -> (T.concat
|
||||||
|
[ "("
|
||||||
|
, name
|
||||||
|
, " IS NULL OR "
|
||||||
|
, name
|
||||||
|
, " IN "
|
||||||
|
, qmarks
|
||||||
|
, ")"
|
||||||
|
], notNullVals)
|
||||||
|
(_, NotIn, 0) -> ("1=1", [])
|
||||||
|
(False, NotIn, _) -> (T.concat
|
||||||
|
[ "("
|
||||||
|
, name
|
||||||
|
, " IS NULL OR "
|
||||||
|
, name
|
||||||
|
, " NOT IN "
|
||||||
|
, qmarks
|
||||||
|
, ")"
|
||||||
|
], notNullVals)
|
||||||
|
(True, NotIn, _) -> (T.concat
|
||||||
|
[ "("
|
||||||
|
, name
|
||||||
|
, " IS NOT NULL AND "
|
||||||
|
, name
|
||||||
|
, " NOT IN "
|
||||||
|
, qmarks
|
||||||
|
, ")"
|
||||||
|
], notNullVals)
|
||||||
|
_ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)
|
||||||
|
|
||||||
|
where
|
||||||
|
isCompFilter Lt = True
|
||||||
|
isCompFilter Le = True
|
||||||
|
isCompFilter Gt = True
|
||||||
|
isCompFilter Ge = True
|
||||||
|
isCompFilter _ = False
|
||||||
|
|
||||||
|
wrapSql sqlcl = "(" <> sqlcl <> ")"
|
||||||
|
fromPersistList (PersistList xs) = xs
|
||||||
|
fromPersistList other = error $ "expected PersistList but found " ++ show other
|
||||||
|
|
||||||
|
filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
|
||||||
|
filterValueToPersistValues v = map toPersistValue $ either return id v
|
||||||
|
|
||||||
|
orNullSuffix =
|
||||||
|
case orNull of
|
||||||
|
OrNullYes -> mconcat [" OR ", name, " IS NULL"]
|
||||||
|
OrNullNo -> ""
|
||||||
|
|
||||||
|
isNull = any (== PersistNull) allVals
|
||||||
|
notNullVals = filter (/= PersistNull) allVals
|
||||||
|
allVals = filterValueToPersistValues value
|
||||||
|
tn = connEscapeName conn $ entityDB
|
||||||
|
$ entityDef $ dummyFromFilts [Filter field value pfilter]
|
||||||
|
name =
|
||||||
|
(if includeTable
|
||||||
|
then ((tn <> ".") <>)
|
||||||
|
else id)
|
||||||
|
$ connEscapeName conn $ fieldName field
|
||||||
|
qmarks = case value of
|
||||||
|
Left _ -> "?"
|
||||||
|
Right x ->
|
||||||
|
let x' = filter (/= PersistNull) $ map toPersistValue x
|
||||||
|
in "(" <> T.intercalate "," (map (const "?") x') <> ")"
|
||||||
|
varCount = case value of
|
||||||
|
Left _ -> 1
|
||||||
|
Right x -> length x
|
||||||
|
showSqlFilter Eq = "="
|
||||||
|
showSqlFilter Ne = "<>"
|
||||||
|
showSqlFilter Gt = ">"
|
||||||
|
showSqlFilter Lt = "<"
|
||||||
|
showSqlFilter Ge = ">="
|
||||||
|
showSqlFilter Le = "<="
|
||||||
|
showSqlFilter In = " IN "
|
||||||
|
showSqlFilter NotIn = " NOT IN "
|
||||||
|
showSqlFilter (BackendSpecificFilter s) = s
|
||||||
|
|
||||||
|
updatePersistValue :: Update v -> PersistValue
|
||||||
|
updatePersistValue (Update _ v _) = toPersistValue v
|
||||||
|
updatePersistValue _ = error "BackendUpdate not implemented"
|
||||||
|
|
||||||
|
filterClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||||
|
=> Bool -- ^ include table name?
|
||||||
|
-> SqlBackend
|
||||||
|
-> [Filter val]
|
||||||
|
-> Text
|
||||||
|
filterClause b c = fst . filterClauseHelper b True c OrNullNo
|
||||||
|
|
||||||
|
orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
|
||||||
|
=> Bool -- ^ include the table name
|
||||||
|
-> SqlBackend
|
||||||
|
-> SelectOpt val
|
||||||
|
-> Text
|
||||||
|
orderClause includeTable conn o =
|
||||||
|
case o of
|
||||||
|
Asc x -> name x
|
||||||
|
Desc x -> name x <> " DESC"
|
||||||
|
_ -> error "orderClause: expected Asc or Desc, not limit or offset"
|
||||||
|
where
|
||||||
|
dummyFromOrder :: SelectOpt a -> Maybe a
|
||||||
|
dummyFromOrder _ = Nothing
|
||||||
|
|
||||||
|
tn = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o
|
||||||
|
|
||||||
|
name :: (PersistEntityBackend record ~ SqlBackend, PersistEntity record)
|
||||||
|
=> EntityField record typ -> Text
|
||||||
|
name x =
|
||||||
|
(if includeTable
|
||||||
|
then ((tn <> ".") <>)
|
||||||
|
else id)
|
||||||
|
$ connEscapeName conn $ fieldName x
|
333
src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs
Normal file
333
src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs
Normal file
|
@ -0,0 +1,333 @@
|
||||||
|
{- This file is part of Vervis.
|
||||||
|
-
|
||||||
|
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
|
||||||
|
-
|
||||||
|
- ♡ Copying is an act of love. Please copy, reuse and share.
|
||||||
|
-
|
||||||
|
- The author(s) have dedicated all copyright and related and neighboring
|
||||||
|
- rights to this software to the public domain worldwide. This software is
|
||||||
|
- distributed without any warranty.
|
||||||
|
-
|
||||||
|
- You should have received a copy of the CC0 Public Domain Dedication along
|
||||||
|
- with this software. If not, see
|
||||||
|
- <http://creativecommons.org/publicdomain/zero/1.0/>.
|
||||||
|
-}
|
||||||
|
|
||||||
|
module Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
|
||||||
|
( deleteRecursivelyWhereCount
|
||||||
|
, updateRecursivelyWhereCount
|
||||||
|
)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Prelude
|
||||||
|
|
||||||
|
import Control.Monad (void)
|
||||||
|
import Control.Monad.IO.Class
|
||||||
|
import Control.Monad.Trans.Reader (ReaderT, ask)
|
||||||
|
import Control.Exception (throwIO)
|
||||||
|
import Data.ByteString.Char8 (readInteger)
|
||||||
|
import Data.Conduit (($=))
|
||||||
|
import Data.Foldable (find)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Data.Maybe (isJust)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Database.Persist
|
||||||
|
import Database.Persist.Sql
|
||||||
|
import Database.Persist.Sql.Util
|
||||||
|
|
||||||
|
import qualified Data.Conduit.List as CL (head, mapM)
|
||||||
|
import qualified Data.Text as T (pack, unpack, intercalate)
|
||||||
|
|
||||||
|
import Database.Persist.Local.Class.PersistQueryRecursive
|
||||||
|
import Database.Persist.Local.Sql.Orphan.Common
|
||||||
|
|
||||||
|
instance PersistQueryRecursive SqlBackend where
|
||||||
|
updateRecursivelyWhere dir field root filts upds =
|
||||||
|
void $ updateRecursivelyWhereCount dir field root filts upds
|
||||||
|
|
||||||
|
deleteRecursivelyWhere dir field root filts =
|
||||||
|
void $ deleteRecursivelyWhereCount dir field root filts
|
||||||
|
|
||||||
|
selectRecursivelySourceRes dir field root filts opts = do
|
||||||
|
conn <- ask
|
||||||
|
let (sql, vals, parse) = sqlValsParse conn
|
||||||
|
srcRes <- rawQueryRes sql vals
|
||||||
|
return $ fmap ($= CL.mapM parse) srcRes
|
||||||
|
where
|
||||||
|
sqlValsParse conn = (sql, vals, parse)
|
||||||
|
where
|
||||||
|
(temp, isRoot, cols, qcols, sqlWith) =
|
||||||
|
withRecursive dir field root conn t (flip entityColumnNames)
|
||||||
|
|
||||||
|
(limit, offset, orders) = limitOffsetOrder opts
|
||||||
|
|
||||||
|
parse xs =
|
||||||
|
case parseEntityValues t xs of
|
||||||
|
Left s -> liftIO $ throwIO $ PersistMarshalError s
|
||||||
|
Right row -> return row
|
||||||
|
t = entityDef $ dummyFromFilts filts
|
||||||
|
wher =
|
||||||
|
if null filts
|
||||||
|
then ""
|
||||||
|
else filterClause False conn filts
|
||||||
|
ord =
|
||||||
|
case map (orderClause False conn) orders of
|
||||||
|
[] -> ""
|
||||||
|
ords -> " ORDER BY " <> T.intercalate "," ords
|
||||||
|
sql =
|
||||||
|
mappend sqlWith $
|
||||||
|
connLimitOffset conn (limit, offset) (not $ null orders) $
|
||||||
|
mconcat
|
||||||
|
[ "SELECT "
|
||||||
|
, cols
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, wher
|
||||||
|
, ord
|
||||||
|
]
|
||||||
|
vals = getFiltsValues conn $ isRoot : filts
|
||||||
|
|
||||||
|
selectRecursivelyKeysRes dir field root filts opts = do
|
||||||
|
conn <- ask
|
||||||
|
let (sql, vals, parse) = sqlValsParse conn
|
||||||
|
srcRes <- rawQueryRes sql vals
|
||||||
|
return $ fmap ($= CL.mapM parse) srcRes
|
||||||
|
where
|
||||||
|
sqlValsParse conn = (sql, vals, parse)
|
||||||
|
where
|
||||||
|
(temp, isRoot, cols, qcols, sqlWith) =
|
||||||
|
withRecursive dir field root conn t dbIdColumns
|
||||||
|
|
||||||
|
(limit, offset, orders) = limitOffsetOrder opts
|
||||||
|
|
||||||
|
parse xs = do
|
||||||
|
keyvals <-
|
||||||
|
case entityPrimary t of
|
||||||
|
Nothing ->
|
||||||
|
case xs of
|
||||||
|
[PersistInt64 x] ->
|
||||||
|
return [PersistInt64 x]
|
||||||
|
[PersistDouble x] ->
|
||||||
|
-- oracle returns Double
|
||||||
|
return [PersistInt64 $ truncate x]
|
||||||
|
_ ->
|
||||||
|
liftIO $ throwIO $ PersistMarshalError $
|
||||||
|
"Unexpected in selectKeys False: " <>
|
||||||
|
T.pack (show xs)
|
||||||
|
Just pdef ->
|
||||||
|
let pks = map fieldHaskell $ compositeFields pdef
|
||||||
|
keyvals =
|
||||||
|
map snd $
|
||||||
|
filter
|
||||||
|
(\ (a, _) ->
|
||||||
|
let ret = isJust (find (== a) pks)
|
||||||
|
in ret
|
||||||
|
) $
|
||||||
|
zip (map fieldHaskell $ entityFields t) xs
|
||||||
|
in return keyvals
|
||||||
|
case keyFromValues keyvals of
|
||||||
|
Right k -> return k
|
||||||
|
Left _ -> error "selectKeysImpl: keyFromValues failed"
|
||||||
|
t = entityDef $ dummyFromFilts filts
|
||||||
|
wher =
|
||||||
|
if null filts
|
||||||
|
then ""
|
||||||
|
else filterClause False conn filts
|
||||||
|
ord =
|
||||||
|
case map (orderClause False conn) orders of
|
||||||
|
[] -> ""
|
||||||
|
ords -> " ORDER BY " <> T.intercalate "," ords
|
||||||
|
sql =
|
||||||
|
mappend sqlWith $
|
||||||
|
connLimitOffset conn (limit, offset) (not $ null orders) $
|
||||||
|
mconcat
|
||||||
|
[ "SELECT "
|
||||||
|
, cols
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, wher
|
||||||
|
, ord
|
||||||
|
]
|
||||||
|
vals = getFiltsValues conn $ isRoot : filts
|
||||||
|
|
||||||
|
countRecursively dir field root filts = do
|
||||||
|
conn <- ask
|
||||||
|
let (sql, vals) = sqlAndVals conn
|
||||||
|
withRawQuery sql vals $ do
|
||||||
|
mm <- CL.head
|
||||||
|
case mm of
|
||||||
|
Just [PersistInt64 i] -> return $ fromIntegral i
|
||||||
|
Just [PersistDouble i] ->return $ fromIntegral (truncate i :: Int64) -- gb oracle
|
||||||
|
Just [PersistByteString i] -> case readInteger i of -- gb mssql
|
||||||
|
Just (ret,"") -> return $ fromIntegral ret
|
||||||
|
xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]"
|
||||||
|
Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]"
|
||||||
|
Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]"
|
||||||
|
where
|
||||||
|
sqlAndVals conn = (sql, vals)
|
||||||
|
where
|
||||||
|
(temp, isRoot, cols, qcols, sqlWith) =
|
||||||
|
withRecursive dir field root conn t dbIdColumns
|
||||||
|
|
||||||
|
t = entityDef $ dummyFromFilts filts
|
||||||
|
wher =
|
||||||
|
if null filts
|
||||||
|
then ""
|
||||||
|
else filterClause False conn filts
|
||||||
|
sql = mconcat
|
||||||
|
[ sqlWith
|
||||||
|
, "SELECT COUNT(*) FROM "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, wher
|
||||||
|
]
|
||||||
|
vals = getFiltsValues conn $ isRoot : filts
|
||||||
|
|
||||||
|
-- | Same as 'deleteRecursivelyWhere', but returns the number of rows affected.
|
||||||
|
deleteRecursivelyWhereCount
|
||||||
|
:: (PersistEntity val, MonadIO m, PersistEntityBackend val ~ SqlBackend)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> ReaderT SqlBackend m Int64
|
||||||
|
deleteRecursivelyWhereCount dir field root filts = do
|
||||||
|
conn <- ask
|
||||||
|
let (sql, vals) = sqlAndVals conn
|
||||||
|
rawExecuteCount sql vals
|
||||||
|
where
|
||||||
|
sqlAndVals conn = (sql, vals)
|
||||||
|
where
|
||||||
|
(temp, isRoot, cols, qcols, sqlWith) =
|
||||||
|
withRecursive dir field root conn t dbIdColumns
|
||||||
|
|
||||||
|
t = entityDef $ dummyFromFilts filts
|
||||||
|
wher = mconcat
|
||||||
|
[ if null filts
|
||||||
|
then " WHERE ( "
|
||||||
|
else filterClause False conn filts <> " AND ( "
|
||||||
|
, connEscapeName conn $ fieldDB $ entityId t
|
||||||
|
, " IN (SELECT "
|
||||||
|
, connEscapeName conn $ fieldDB $ entityId t
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, ") ) "
|
||||||
|
]
|
||||||
|
sql = mconcat
|
||||||
|
[ sqlWith
|
||||||
|
, "DELETE FROM "
|
||||||
|
, connEscapeName conn $ entityDB t
|
||||||
|
, wher
|
||||||
|
]
|
||||||
|
vals = getFiltsValues conn $ isRoot : filts
|
||||||
|
|
||||||
|
-- | Same as 'updateRecursivelyWhere', but returns the number of rows affected.
|
||||||
|
updateRecursivelyWhereCount
|
||||||
|
:: (PersistEntity val, MonadIO m, SqlBackend ~ PersistEntityBackend val)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> [Filter val]
|
||||||
|
-> [Update val]
|
||||||
|
-> ReaderT SqlBackend m Int64
|
||||||
|
updateRecursivelyWhereCount _ _ _ _ [] = return 0
|
||||||
|
updateRecursivelyWhereCount dir field root filts upds = do
|
||||||
|
conn <- ask
|
||||||
|
let (sql, vals) = sqlAndVals conn
|
||||||
|
rawExecuteCount sql vals
|
||||||
|
where
|
||||||
|
sqlAndVals conn = (sql, vals)
|
||||||
|
where
|
||||||
|
(temp, isRoot, cols, qcols, sqlWith) =
|
||||||
|
withRecursive dir field root conn t dbIdColumns
|
||||||
|
|
||||||
|
t = entityDef $ dummyFromFilts filts
|
||||||
|
|
||||||
|
go'' n Assign = n <> "=?"
|
||||||
|
go'' n Add = mconcat [n, "=", n, "+?"]
|
||||||
|
go'' n Subtract = mconcat [n, "=", n, "-?"]
|
||||||
|
go'' n Multiply = mconcat [n, "=", n, "*?"]
|
||||||
|
go'' n Divide = mconcat [n, "=", n, "/?"]
|
||||||
|
go'' _ (BackendSpecificUpdate up) =
|
||||||
|
error $ T.unpack $ "BackendSpecificUpdate " <> up <> " not supported"
|
||||||
|
go' (x, pu) = go'' (connEscapeName conn x) pu
|
||||||
|
go x = (updateField x, updateUpdate x)
|
||||||
|
|
||||||
|
updateField (Update f _ _) = fieldName f
|
||||||
|
updateField _ = error "BackendUpdate not implemented"
|
||||||
|
|
||||||
|
wher = mconcat
|
||||||
|
[ if null filts
|
||||||
|
then " WHERE ( "
|
||||||
|
else filterClause False conn filts <> " AND ( "
|
||||||
|
, connEscapeName conn $ fieldDB $ entityId t
|
||||||
|
, " IN (SELECT "
|
||||||
|
, connEscapeName conn $ fieldDB $ entityId t
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, ") ) "
|
||||||
|
]
|
||||||
|
sql = mconcat
|
||||||
|
[ sqlWith
|
||||||
|
, "UPDATE "
|
||||||
|
, connEscapeName conn $ entityDB t
|
||||||
|
, " SET "
|
||||||
|
, T.intercalate "," $ map (go' . go) upds
|
||||||
|
, wher
|
||||||
|
]
|
||||||
|
vals =
|
||||||
|
getFiltsValues conn [isRoot] ++
|
||||||
|
map updatePersistValue upds ++
|
||||||
|
getFiltsValues conn filts
|
||||||
|
|
||||||
|
withRecursive
|
||||||
|
:: (PersistEntity val, SqlBackend ~ PersistEntityBackend val)
|
||||||
|
=> RecursionDirection
|
||||||
|
-> EntityField val (Maybe (Key val))
|
||||||
|
-> Key val
|
||||||
|
-> SqlBackend
|
||||||
|
-> EntityDef
|
||||||
|
-> (SqlBackend -> EntityDef -> [Text])
|
||||||
|
-> (DBName, Filter val, Text, DBName -> Text, Text)
|
||||||
|
withRecursive dir field root conn t getcols =
|
||||||
|
let temp = DBName "temp_hierarchy_cte"
|
||||||
|
isRoot = persistIdField ==. root
|
||||||
|
cols = T.intercalate "," $ getcols conn t
|
||||||
|
qcols name =
|
||||||
|
T.intercalate ", " $
|
||||||
|
map ((connEscapeName conn name <>) . ("." <>)) $
|
||||||
|
getcols conn t
|
||||||
|
sql = mconcat
|
||||||
|
[ "WITH RECURSIVE "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, "("
|
||||||
|
, cols
|
||||||
|
, ") AS ( SELECT "
|
||||||
|
, cols
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn $ entityDB t
|
||||||
|
, filterClause False conn [isRoot]
|
||||||
|
--, " WHERE "
|
||||||
|
--, connEscapeName conn $ fieldDB $ entityId t
|
||||||
|
--, " = ?"
|
||||||
|
, " UNION SELECT "
|
||||||
|
, qcols temp
|
||||||
|
, " FROM "
|
||||||
|
, connEscapeName conn $ entityDB t
|
||||||
|
, ", "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, " WHERE "
|
||||||
|
, connEscapeName conn $ entityDB t
|
||||||
|
, "."
|
||||||
|
, connEscapeName conn $ fieldDB $ case dir of
|
||||||
|
RecOut -> persistFieldDef field
|
||||||
|
RecIn -> entityId t
|
||||||
|
, " = "
|
||||||
|
, connEscapeName conn temp
|
||||||
|
, "."
|
||||||
|
, connEscapeName conn $ fieldDB $ case dir of
|
||||||
|
RecOut -> entityId t
|
||||||
|
RecIn -> persistFieldDef field
|
||||||
|
, " ) "
|
||||||
|
]
|
||||||
|
in (temp, isRoot, cols, qcols, sql)
|
|
@ -64,6 +64,9 @@ library
|
||||||
Database.Esqueleto.Local
|
Database.Esqueleto.Local
|
||||||
Database.Persist.Class.Local
|
Database.Persist.Class.Local
|
||||||
Database.Persist.Sql.Local
|
Database.Persist.Sql.Local
|
||||||
|
Database.Persist.Local.Class.PersistQueryRecursive
|
||||||
|
Database.Persist.Local.Sql.Orphan.Common
|
||||||
|
Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
|
||||||
Development.DarcsRev
|
Development.DarcsRev
|
||||||
Formatting.CaseInsensitive
|
Formatting.CaseInsensitive
|
||||||
Network.SSH.Local
|
Network.SSH.Local
|
||||||
|
@ -219,6 +222,8 @@ library
|
||||||
, memory
|
, memory
|
||||||
, monad-control
|
, monad-control
|
||||||
, monad-logger
|
, monad-logger
|
||||||
|
-- for Database.Persist.Local
|
||||||
|
, mtl
|
||||||
, pandoc
|
, pandoc
|
||||||
, pandoc-types
|
, pandoc-types
|
||||||
-- for PathPiece instance for CI, Web.PathPieces.Local
|
-- for PathPiece instance for CI, Web.PathPieces.Local
|
||||||
|
@ -227,6 +232,8 @@ library
|
||||||
, persistent-postgresql
|
, persistent-postgresql
|
||||||
, persistent-template
|
, persistent-template
|
||||||
, process
|
, process
|
||||||
|
-- for Database.Persist.Local
|
||||||
|
, resourcet
|
||||||
, safe
|
, safe
|
||||||
, shakespeare
|
, shakespeare
|
||||||
, ssh
|
, ssh
|
||||||
|
|
Loading…
Reference in a new issue