diff --git a/spec/spies_spec.lua b/spec/spies_spec.lua index d0dbe11..9baa253 100644 --- a/spec/spies_spec.lua +++ b/spec/spies_spec.lua @@ -78,6 +78,18 @@ describe("Tests dealing with spies", function() .. "%(values list%) %(%(number%) 5, %(matcher%) _ %*anything%*%)") end) -- (matcher) _ *anything* + it("checks returned_with() order assertions", function() + local s = spy.new(function(...) return ... end) + + s(1, 2, 3) + s("a", "b", "c") + assert.spy_call(s, 1).was.returned_with(1, 2, 3) + assert.spy_call(s, 2).was.returned_with("a", "b", "c") + assert.spy_call(s, 1).was_not.returned_with("a", "b", "c") + assert.spy_call(s, 2).was_not.returned_with(1, 2, 3) + assert.spy_call(s, 3).was_not.returned_with(1, 2, 3) + end) + it("checks called() and called_with() assertions", function() local s = spy.new(function() end) local t = { foo = { bar = { "test" } } } @@ -142,6 +154,18 @@ describe("Tests dealing with spies", function() assert.spy(s).was.called_with(s) end) + it("checks called_with order assertions", function() + local s = spy.new(function() end) + + s(1, 2, 3) + s("a", "b", "c") + assert.spy_call(s, 1).was.called_with(1, 2, 3) + assert.spy_call(s, 2).was.called_with("a", "b", "c") + assert.spy_call(s, 1).was_not.called_with("a", "b", "c") + assert.spy_call(s, 2).was_not.called_with(1, 2, 3) + assert.spy_call(s, 3).was_not.called_with(1, 2, 3) + end) + it("checks called_at_least() assertions", function() local s = spy.new(function() end) diff --git a/src/spy.lua b/src/spy.lua index eb7fc06..c5f75fb 100644 --- a/src/spy.lua +++ b/src/spy.lua @@ -56,24 +56,44 @@ spy = { return (#self.calls > 0), #self.calls end, - called_with = function(self, args) + called_with = function(self, args, call_number) local last_arglist = nil - if #self.calls > 0 then - last_arglist = self.calls[#self.calls].vals + local matching_arglists = nil + + if call_number ~= nil then + local call = self.calls[call_number] + if call ~= nil then + last_arglist = call.vals + matching_arglists = util.matchargs({call}, args) + end + else + if #self.calls > 0 then + last_arglist = self.calls[#self.calls].vals + end + matching_arglists = util.matchargs(self.calls, args) end - local matching_arglists = util.matchargs(self.calls, args) if matching_arglists ~= nil then return true, matching_arglists.vals end return false, last_arglist end, - returned_with = function(self, args) + returned_with = function(self, args, call_number) local last_returnvallist = nil - if #self.returnvals > 0 then - last_returnvallist = self.returnvals[#self.returnvals].vals + local matching_returnvallists = nil + + if call_number ~= nil then + local returnval = self.returnvals[call_number] + if returnval ~= nil then + last_returnvallist = returnval.vals + matching_returnvallists = util.matchargs({returnval}, args) + end + else + if #self.returnvals > 0 then + last_returnvallist = self.returnvals[#self.returnvals].vals + end + matching_returnvallists = util.matchargs(self.returnvals, args) end - local matching_returnvallists = util.matchargs(self.returnvals, args) if matching_returnvallists ~= nil then return true, matching_returnvallists.vals end @@ -106,11 +126,20 @@ local function set_spy(state, arguments, level) end end +local function set_spy_call(state, arguments, level) + state.payload = arguments[1] + state.call_number = arguments[2] + if arguments[3] ~= nil then + state.failure_message = arguments[3] + end +end + local function returned_with(state, arguments, level) local level = (level or 1) + 1 local payload = rawget(state, "payload") + local call_number = rawget(state, "call_number") if payload and payload.returned_with then - local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments) + local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments, call_number) local expected_returnvallist = util.shallowcopy(arguments) util.cleararglist(arguments) util.tinsert(arguments, 1, matching_or_last_returnvallist) @@ -124,8 +153,9 @@ end local function called_with(state, arguments, level) local level = (level or 1) + 1 local payload = rawget(state, "payload") + local call_number = rawget(state, "call_number") if payload and payload.called_with then - local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments) + local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments, call_number) local expected_arglist = util.shallowcopy(arguments) util.cleararglist(arguments) util.tinsert(arguments, 1, matching_or_last_arglist) @@ -180,6 +210,7 @@ local function called_less_than(state, arguments, level) end assert:register("modifier", "spy", set_spy) +assert:register("modifier", "spy_call", set_spy_call) assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive", "assertion.returned_with.negative") assert:register("assertion", "called_with", called_with, "assertion.called_with.positive", "assertion.called_with.negative") assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")