Skip to content

Commit

Permalink
rewrite modle:update to not commit changes to the instance object unt…
Browse files Browse the repository at this point in the history
…il the update has completed successfully
  • Loading branch information
leafo committed Nov 2, 2023
1 parent 1cbd6cc commit 6952cd8
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 58 deletions.
107 changes: 77 additions & 30 deletions lapis/db/base_model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,18 @@ do
end,
update = function(self, first, ...)
local cond = self:_primary_cond()
local update_fields = { }
local columns
if type(first) == "table" then
do
local _accum_0 = { }
local _len_0 = 1
for k, v in pairs(first) do
if type(k) == "number" then
update_fields[v] = self[v]
_accum_0[_len_0] = v
else
self[k] = v
update_fields[k] = v
_accum_0[_len_0] = k
end
_len_0 = _len_0 + 1
Expand All @@ -288,29 +290,25 @@ do
first,
...
}
for _index_0 = 1, #columns do
local c = columns[_index_0]
update_fields[c] = self[c]
end
end
if next(columns) == nil then
return nil, "nothing to update"
end
if self.__class.constraints then
for _, column in pairs(columns) do
for _index_0 = 1, #columns do
local column = columns[_index_0]
do
local err = self.__class:_check_constraint(column, self[column], self)
local err = self.__class:_check_constraint(column, update_fields[column], self)
if err then
return nil, err
end
end
end
end
local values
do
local _tbl_0 = { }
for _index_0 = 1, #columns do
local col = columns[_index_0]
_tbl_0[col] = self[col]
end
values = _tbl_0
end
local nargs = select("#", ...)
local last = nargs > 0 and select(nargs, ...)
local opts
Expand All @@ -319,7 +317,7 @@ do
end
if self.__class.timestamp and not (opts and opts.timestamp == false) then
local time = self.__class.db.format_date()
values.updated_at = values.updated_at or time
update_fields.updated_at = update_fields.updated_at or time
end
if opts and opts.where then
assert(type(opts.where) == "table", "Model.update: where condition must be a table or db.clause")
Expand All @@ -334,31 +332,80 @@ do
where
})
end
local returning
for k, v in pairs(values) do
if v == self.__class.db.NULL then
self[k] = nil
elseif self.__class.db.is_raw(v) then
returning = returning or { }
table.insert(returning, k)
local returning, return_all
if opts and opts.returning then
if opts.returning == "*" then
return_all = true
returning = {
self.__class.db.raw("*")
}
else
returning = {
unpack(opts.returning)
}
end
end
for k, v in pairs(update_fields) do
local _continue_0 = false
repeat
if v == self.__class.db.NULL then
_continue_0 = true
break
end
if self.__class.db.is_raw(v) then
returning = returning or { }
table.insert(returning, k)
end
_continue_0 = true
until true
if not _continue_0 then
break
end
end
local res
if returning then
res = self.__class.db.update(self.__class:table_name(), values, cond, unpack(returning))
do
local update = unpack(res)
if update then
for _index_0 = 1, #returning do
local k = returning[_index_0]
self[k] = update[k]
res = self.__class.db.update(self.__class:table_name(), update_fields, cond, unpack(returning))
else
res = self.__class.db.update(self.__class:table_name(), update_fields, cond)
end
local did_update = (res.affected_rows or 0) > 0
if did_update then
for k, v in pairs(update_fields) do
if v == self.__class.db.NULL then
self[k] = nil
else
self[k] = v
end
end
if returning then
do
local result_row = unpack(res)
if result_row then
if return_all then
for k, v in pairs(result_row) do
self[k] = v
end
end
for _index_0 = 1, #returning do
local _continue_0 = false
repeat
local k = returning[_index_0]
if not (type(k) == "string") then
_continue_0 = true
break
end
self[k] = result_row[k]
_continue_0 = true
until true
if not _continue_0 then
break
end
end
end
end
end
else
res = self.__class.db.update(self.__class:table_name(), values, cond)
end
return (res.affected_rows or 0) > 0, res
return did_update, res
end,
refresh = function(self, fields, ...)
if fields == nil then
Expand Down
76 changes: 53 additions & 23 deletions lapis/db/base_model.moon
Original file line number Diff line number Diff line change
Expand Up @@ -671,28 +671,33 @@ class BaseModel
-- col3: "Hello"
-- }
-- NOTE: this implementation depends on support for RETURNING sql synax
-- TODO: update by field name should be deprecated
update: (first, ...) =>
cond = @_primary_cond!

columns = if type(first) == "table"
for k,v in pairs first
update_fields = {} -- the columns and their new values to update
local columns

if type(first) == "table"
columns = for k,v in pairs first
if type(k) == "number"
update_fields[v] = @[v]
v
else
@[k] = v
update_fields[k] = v
k
else
{first, ...}
columns = {first, ...}
for c in *columns
update_fields[c] = @[c]

return nil, "nothing to update" if next(columns) == nil

if @@constraints
for _, column in pairs columns
if err = @@_check_constraint column, @[column], @
for column in *columns
if err = @@_check_constraint column, update_fields[column], @
return nil, err

values = { col, @[col] for col in *columns }

-- update options
nargs = select "#", ...
last = nargs > 0 and select nargs, ...
Expand All @@ -701,7 +706,7 @@ class BaseModel

if @@timestamp and not (opts and opts.timestamp == false)
time = @@db.format_date!
values.updated_at or= time
update_fields.updated_at or= time

if opts and opts.where
assert type(opts.where) == "table", "Model.update: where condition must be a table or db.clause"
Expand All @@ -716,26 +721,51 @@ class BaseModel
where
}

local returning
for k, v in pairs values
if v == @@db.NULL
@[k] = nil
elseif @@db.is_raw(v)
local returning, return_all -- TODO: verify that returning * works

if opts and opts.returning
if opts.returning == "*"
return_all = true
returning = {@@db.raw "*"}
else
returning = {unpack opts.returning}

for k, v in pairs update_fields
continue if v == @@db.NULL -- NULL is raw but handled specially
if @@db.is_raw v
returning or= {}
table.insert returning, k

local res

if returning
res = @@db.update @@table_name!, values, cond, unpack returning
if update = unpack res
for k in *returning
@[k] = update[k]
res = if returning
@@db.update @@table_name!, update_fields, cond, unpack returning
else
res = @@db.update @@table_name!, values, cond
@@db.update @@table_name!, update_fields, cond

(res.affected_rows or 0) > 0, res
did_update = (res.affected_rows or 0) > 0

-- if the update completed, store the values into self
if did_update
-- NOTE: this is redundant if the column name variant is used to issue an update
for k, v in pairs update_fields
if v == @@db.NULL
@[k] = nil
else
@[k] = v

if returning
if result_row = unpack res
if return_all
for k, v in pairs result_row
@[k] = v

-- we still have to iterate over the name list to ensure that we nil
-- out explicitly requested fields, since db.NULL is not returned in
-- the result set
for k in *returning
continue unless type(k) == "string"
@[k] = result_row[k]

did_update, res

-- reload fields on the instance
refresh: (fields="*", ...) =>
Expand Down
Loading

0 comments on commit 6952cd8

Please sign in to comment.