From f466d7163ebd421855d154b7db0de4310517503e Mon Sep 17 00:00:00 2001 From: Vivien HENRIET Date: Fri, 20 Jun 2025 13:53:20 +0200 Subject: [PATCH] feat(assert): add support for indexed spy call assertions via assert.spy_call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch modifies the assert module to support indexed assertions on spy calls through a new entry point: `assert.spy_call(spy, n)`. It allows making assertions about a specific call by index in a clear and explicit way. Example usage: local s = spy(function() end) s("foo") s("bar") assert.spy_call(s, 1).was_called_with("foo") assert.spy_call(s, 2).was_called_with("bar") This makes it possible to verify the order and arguments of individual calls without relying on global state or chaining. Each assertion remains pure, isolated, and compatible with Luassert’s philosophy. No breaking changes. --- spec/spies_spec.lua | 24 +++++++++++++++++++++ src/spy.lua | 51 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/spec/spies_spec.lua b/spec/spies_spec.lua index d0dbe11..9baa253 100644 --- a/spec/spies_spec.lua +++ b/spec/spies_spec.lua @@ -78,6 +78,18 @@ describe("Tests dealing with spies", function() .. "%(values list%) %(%(number%) 5, %(matcher%) _ %*anything%*%)") end) -- (matcher) _ *anything* + it("checks returned_with() order assertions", function() + local s = spy.new(function(...) return ... end) + + s(1, 2, 3) + s("a", "b", "c") + assert.spy_call(s, 1).was.returned_with(1, 2, 3) + assert.spy_call(s, 2).was.returned_with("a", "b", "c") + assert.spy_call(s, 1).was_not.returned_with("a", "b", "c") + assert.spy_call(s, 2).was_not.returned_with(1, 2, 3) + assert.spy_call(s, 3).was_not.returned_with(1, 2, 3) + end) + it("checks called() and called_with() assertions", function() local s = spy.new(function() end) local t = { foo = { bar = { "test" } } } @@ -142,6 +154,18 @@ describe("Tests dealing with spies", function() assert.spy(s).was.called_with(s) end) + it("checks called_with order assertions", function() + local s = spy.new(function() end) + + s(1, 2, 3) + s("a", "b", "c") + assert.spy_call(s, 1).was.called_with(1, 2, 3) + assert.spy_call(s, 2).was.called_with("a", "b", "c") + assert.spy_call(s, 1).was_not.called_with("a", "b", "c") + assert.spy_call(s, 2).was_not.called_with(1, 2, 3) + assert.spy_call(s, 3).was_not.called_with(1, 2, 3) + end) + it("checks called_at_least() assertions", function() local s = spy.new(function() end) diff --git a/src/spy.lua b/src/spy.lua index eb7fc06..c5f75fb 100644 --- a/src/spy.lua +++ b/src/spy.lua @@ -56,24 +56,44 @@ spy = { return (#self.calls > 0), #self.calls end, - called_with = function(self, args) + called_with = function(self, args, call_number) local last_arglist = nil - if #self.calls > 0 then - last_arglist = self.calls[#self.calls].vals + local matching_arglists = nil + + if call_number ~= nil then + local call = self.calls[call_number] + if call ~= nil then + last_arglist = call.vals + matching_arglists = util.matchargs({call}, args) + end + else + if #self.calls > 0 then + last_arglist = self.calls[#self.calls].vals + end + matching_arglists = util.matchargs(self.calls, args) end - local matching_arglists = util.matchargs(self.calls, args) if matching_arglists ~= nil then return true, matching_arglists.vals end return false, last_arglist end, - returned_with = function(self, args) + returned_with = function(self, args, call_number) local last_returnvallist = nil - if #self.returnvals > 0 then - last_returnvallist = self.returnvals[#self.returnvals].vals + local matching_returnvallists = nil + + if call_number ~= nil then + local returnval = self.returnvals[call_number] + if returnval ~= nil then + last_returnvallist = returnval.vals + matching_returnvallists = util.matchargs({returnval}, args) + end + else + if #self.returnvals > 0 then + last_returnvallist = self.returnvals[#self.returnvals].vals + end + matching_returnvallists = util.matchargs(self.returnvals, args) end - local matching_returnvallists = util.matchargs(self.returnvals, args) if matching_returnvallists ~= nil then return true, matching_returnvallists.vals end @@ -106,11 +126,20 @@ local function set_spy(state, arguments, level) end end +local function set_spy_call(state, arguments, level) + state.payload = arguments[1] + state.call_number = arguments[2] + if arguments[3] ~= nil then + state.failure_message = arguments[3] + end +end + local function returned_with(state, arguments, level) local level = (level or 1) + 1 local payload = rawget(state, "payload") + local call_number = rawget(state, "call_number") if payload and payload.returned_with then - local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments) + local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments, call_number) local expected_returnvallist = util.shallowcopy(arguments) util.cleararglist(arguments) util.tinsert(arguments, 1, matching_or_last_returnvallist) @@ -124,8 +153,9 @@ end local function called_with(state, arguments, level) local level = (level or 1) + 1 local payload = rawget(state, "payload") + local call_number = rawget(state, "call_number") if payload and payload.called_with then - local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments) + local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments, call_number) local expected_arglist = util.shallowcopy(arguments) util.cleararglist(arguments) util.tinsert(arguments, 1, matching_or_last_arglist) @@ -180,6 +210,7 @@ local function called_less_than(state, arguments, level) end assert:register("modifier", "spy", set_spy) +assert:register("modifier", "spy_call", set_spy_call) assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive", "assertion.returned_with.negative") assert:register("assertion", "called_with", called_with, "assertion.called_with.positive", "assertion.called_with.negative") assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")