@@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout)
1212 @test dot (Δin, rΔin[2 : end ]) ≈ dot (fΔout, Δout)
1313end
1414
15+ function _dot (p, q)
16+ monos = monovec ([monomials (p); monomials (q)])
17+ return dot (coefficient .(p, monos), coefficient .(q, monos))
18+ end
19+ function _dot (px:: Tuple , qx:: Tuple )
20+ return _dot (first (px), first (qx)) + _dot (Base. tail (px), Base. tail (qx))
21+ end
22+ function _dot (:: Tuple{} , :: Tuple{} )
23+ return MultivariatePolynomials. MA. Zero ()
24+ end
25+ function _dot (:: NoTangent , :: NoTangent )
26+ return MultivariatePolynomials. MA. Zero ()
27+ end
28+
1529@testset " ChainRulesCore" begin
1630 Mod. @polyvar x y
1731 p = 1.1 x + y
4256 @test pullback (q) == (NoTangent (), (- 0.2 + 2im ) * x^ 2 - x* y, NoTangent ())
4357 @test pullback (1 x) == (NoTangent (), 2 x^ 2 , NoTangent ())
4458
45- test_chain_rule (dot, + , (p,), (q,), p)
46- test_chain_rule (dot, + , (q,), (p,), q)
59+ for d in [dot, _dot]
60+ test_chain_rule (d, + , (p,), (q,), p)
61+ test_chain_rule (d, + , (q,), (p,), q)
4762
48- test_chain_rule (dot , - , (p,), (q,), p)
49- test_chain_rule (dot , - , (p,), (p,), q)
63+ test_chain_rule (d , - , (p,), (q,), p)
64+ test_chain_rule (d , - , (p,), (p,), q)
5065
51- test_chain_rule (dot , + , (p, q), (q, p), p)
52- test_chain_rule (dot , + , (p, q), (p, q), q)
66+ test_chain_rule (d , + , (p, q), (q, p), p)
67+ test_chain_rule (d , + , (p, q), (p, q), q)
5368
54- test_chain_rule (dot, - , (p, q), (q, p), p)
55- test_chain_rule (dot, - , (p, q), (p, q), q)
69+ test_chain_rule (d, - , (p, q), (q, p), p)
70+ test_chain_rule (d, - , (p, q), (p, q), q)
71+ end
5672
57- test_chain_rule (dot , * , (p, q), (q, p), p * q)
58- test_chain_rule (dot , * , (p, q), (p, q), q * q)
59- test_chain_rule (dot , * , (q, p), (p, q), q * q)
60- test_chain_rule (dot , * , (p, q), (q, p), q * q)
73+ test_chain_rule (_dot , * , (p, q), (q, p), p * q)
74+ test_chain_rule (_dot , * , (p, q), (p, q), q * q)
75+ test_chain_rule (_dot , * , (q, p), (p, q), q * q)
76+ test_chain_rule (_dot , * , (p, q), (q, p), q * q)
6177
62- function _dot (p, q)
63- monos = monomials (p + q)
64- return dot (coefficient .(p, monos), coefficient .(q, monos))
65- end
66- function _dot (px:: Tuple{<:AbstractPolynomial,NoTangent} , qx:: Tuple{<:AbstractPolynomial,NoTangent} )
67- return _dot (px[1 ], qx[1 ])
68- end
6978 test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), p)
7079 test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), differentiate (p, x))
7180 test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), differentiate (q, x))
0 commit comments