Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions spec/spies_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ describe("Tests dealing with spies", function()
.. "%(values list%) %(%(number%) 5, %(matcher%) _ %*anything%*%)")
end) -- (matcher) _ *anything*

it("checks returned_with() order assertions", function()
local s = spy.new(function(...) return ... end)

s(1, 2, 3)
s("a", "b", "c")
assert.spy_call(s, 1).was.returned_with(1, 2, 3)
assert.spy_call(s, 2).was.returned_with("a", "b", "c")
assert.spy_call(s, 1).was_not.returned_with("a", "b", "c")
assert.spy_call(s, 2).was_not.returned_with(1, 2, 3)
assert.spy_call(s, 3).was_not.returned_with(1, 2, 3)
end)

it("checks called() and called_with() assertions", function()
local s = spy.new(function() end)
local t = { foo = { bar = { "test" } } }
Expand Down Expand Up @@ -142,6 +154,18 @@ describe("Tests dealing with spies", function()
assert.spy(s).was.called_with(s)
end)

it("checks called_with order assertions", function()
local s = spy.new(function() end)

s(1, 2, 3)
s("a", "b", "c")
assert.spy_call(s, 1).was.called_with(1, 2, 3)
assert.spy_call(s, 2).was.called_with("a", "b", "c")
assert.spy_call(s, 1).was_not.called_with("a", "b", "c")
assert.spy_call(s, 2).was_not.called_with(1, 2, 3)
assert.spy_call(s, 3).was_not.called_with(1, 2, 3)
end)

it("checks called_at_least() assertions", function()
local s = spy.new(function() end)

Expand Down
51 changes: 41 additions & 10 deletions src/spy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,44 @@ spy = {
return (#self.calls > 0), #self.calls
end,

called_with = function(self, args)
called_with = function(self, args, call_number)
local last_arglist = nil
if #self.calls > 0 then
last_arglist = self.calls[#self.calls].vals
local matching_arglists = nil

if call_number ~= nil then
local call = self.calls[call_number]
if call ~= nil then
last_arglist = call.vals
matching_arglists = util.matchargs({call}, args)
end
else
if #self.calls > 0 then
last_arglist = self.calls[#self.calls].vals
end
matching_arglists = util.matchargs(self.calls, args)
end
local matching_arglists = util.matchargs(self.calls, args)
if matching_arglists ~= nil then
return true, matching_arglists.vals
end
return false, last_arglist
end,

returned_with = function(self, args)
returned_with = function(self, args, call_number)
local last_returnvallist = nil
if #self.returnvals > 0 then
last_returnvallist = self.returnvals[#self.returnvals].vals
local matching_returnvallists = nil

if call_number ~= nil then
local returnval = self.returnvals[call_number]
if returnval ~= nil then
last_returnvallist = returnval.vals
matching_returnvallists = util.matchargs({returnval}, args)
end
else
if #self.returnvals > 0 then
last_returnvallist = self.returnvals[#self.returnvals].vals
end
matching_returnvallists = util.matchargs(self.returnvals, args)
end
local matching_returnvallists = util.matchargs(self.returnvals, args)
if matching_returnvallists ~= nil then
return true, matching_returnvallists.vals
end
Expand Down Expand Up @@ -106,11 +126,20 @@ local function set_spy(state, arguments, level)
end
end

local function set_spy_call(state, arguments, level)
state.payload = arguments[1]
state.call_number = arguments[2]
if arguments[3] ~= nil then
state.failure_message = arguments[3]
end
end

local function returned_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
local call_number = rawget(state, "call_number")
if payload and payload.returned_with then
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments)
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments, call_number)
local expected_returnvallist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_returnvallist)
Expand All @@ -124,8 +153,9 @@ end
local function called_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
local call_number = rawget(state, "call_number")
if payload and payload.called_with then
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments)
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments, call_number)
local expected_arglist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_arglist)
Expand Down Expand Up @@ -180,6 +210,7 @@ local function called_less_than(state, arguments, level)
end

assert:register("modifier", "spy", set_spy)
assert:register("modifier", "spy_call", set_spy_call)
assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive", "assertion.returned_with.negative")
assert:register("assertion", "called_with", called_with, "assertion.called_with.positive", "assertion.called_with.negative")
assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")
Expand Down