1- import ProximalOperators: gradient! # this can be removed when moved to Prox
1+ import ProximalOperators: gradient!, gradient # this can be removed when moved to Prox
22
33export PrecomposeNonlinear
44
55struct PrecomposeNonlinear{P <: ProximableFunction ,
66 T <: AbstractOperator ,
7- D, C
7+ D <: AbstractArray ,
8+ C <: AbstractArray
89 } <: ProximableFunction
910 g:: P
10- G:: T
11- bufD:: D
12- bufC:: C
13- bufC2:: C
11+ G:: T
12+ bufD:: D
13+ bufC:: C
14+ bufC2:: C
1415end
1516
1617function PrecomposeNonlinear (g:: P , G:: T ) where {P, T}
17- bufD = blockzeros (domainType (G), size (G,2 ))
18- bufC = blockzeros (codomainType (G),size (G,1 ))
19- bufC2 = blockzeros (codomainType (G),size (G,1 ))
20- PrecomposeNonlinear {P, T, typeof(bufD), typeof(bufC)} (g, G, bufD, bufC, bufC2)
18+ t, s = domainType (G), size (G,2 )
19+ bufD = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s))
20+ t, s = codomainType (G), size (G,1 )
21+ bufC = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s))
22+ bufC2 = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s))
23+ PrecomposeNonlinear {P, T, typeof(bufD), typeof(bufC)} (g, G, bufD, bufC, bufC2)
2124end
2225
2326is_smooth (f:: PrecomposeNonlinear ) = is_smooth (f. g)
@@ -26,7 +29,22 @@ function (f::PrecomposeNonlinear)(x)
2629 return f. g (f. G* x)
2730end
2831
29- function gradient! (y:: D , f:: PrecomposeNonlinear{P,T,D,C} , x:: D ) where {P,T,D,C}
32+ function gradient (f:: PrecomposeNonlinear , x:: ArrayPartition )
33+ y = zero (x)
34+ fy = gradient! (y,f,x)
35+ return y, fy
36+ end
37+
38+ # TODO simplify this
39+ function gradient! (y:: D , f:: PrecomposeNonlinear{P,T,D,C} , x:: D ) where {P,T,D <: ArrayPartition ,C}
40+ mul! (f. bufC, f. G, x)
41+ v = gradient! (f. bufC2, f. g, f. bufC)
42+ J = Jacobian (f. G, x)
43+ y = mul! (y, J' , f. bufC2)
44+ return v
45+ end
46+
47+ function gradient! (y:: D , f:: PrecomposeNonlinear{P,T,D,C} , x:: D ) where {P,T,D <: AbstractArray ,C}
3048 mul! (f. bufC, f. G, x)
3149 v = gradient! (f. bufC2, f. g, f. bufC)
3250 J = Jacobian (f. G, x)
0 commit comments