11package com .avsystem .commons
22package jetty .rpc
33
4- import java .nio .charset .StandardCharsets
54import com .avsystem .commons .rpc .StandardRPCFramework
65import com .avsystem .commons .serialization .json .{JsonStringInput , JsonStringOutput , RawJson }
76import com .avsystem .commons .serialization .{GenCodec , HasGenCodec }
87import com .typesafe .scalalogging .LazyLogging
9-
10- import javax .servlet .http .{HttpServletRequest , HttpServletResponse }
11- import org .eclipse .jetty .client .HttpClient
12- import org .eclipse .jetty .client .api .Result
13- import org .eclipse .jetty .client .util .{BufferingResponseListener , StringContentProvider , StringRequestContent }
8+ import jakarta .servlet .http .{HttpServlet , HttpServletRequest , HttpServletResponse }
9+ import org .eclipse .jetty .client .{BufferingResponseListener , HttpClient , Result , StringRequestContent }
10+ import org .eclipse .jetty .ee10 .servlet .ServletContextHandler
1411import org .eclipse .jetty .http .{HttpMethod , HttpStatus , MimeTypes }
15- import org .eclipse .jetty .server .handler .AbstractHandler
16- import org .eclipse .jetty .server .{Handler , Request }
12+ import org .eclipse .jetty .server .Handler
1713
18- import scala .concurrent .duration ._
14+ import java .nio .charset .StandardCharsets
15+ import java .util .concurrent .atomic .AtomicBoolean
16+ import scala .concurrent .duration .*
17+ import scala .util .Using
1918
2019object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
2120 class RawValue (val s : String ) extends AnyVal
@@ -89,30 +88,40 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
8988 request(HttpMethod .PUT , call)
9089 }
9190
92- class RPCHandler (rootRpc : RawRPC , contextTimeout : FiniteDuration ) extends AbstractHandler {
93- override def handle (target : String , baseRequest : Request , request : HttpServletRequest , response : HttpServletResponse ): Unit = {
94- baseRequest.setHandled(true )
95-
96- val content = Iterator .continually(request.getReader.readLine())
97- .takeWhile(_ != null )
98- .mkString(" \n " )
99-
100- val call = read[Call ](new RawValue (content))
91+ class RPCHandler (rootRpc : RawRPC , contextTimeout : FiniteDuration ) extends HttpServlet {
92+ override def service (request : HttpServletRequest , response : HttpServletResponse ): Unit = {
93+ // readRequest must execute in request thread but we want exceptions to be handled uniformly, hence the Try
94+ val content =
95+ Using (request.getReader)(reader =>
96+ Iterator .continually(reader.readLine()).takeWhile(_ != null ).mkString(" \n " )
97+ )
98+ val call = content.map(content => read[Call ](new RawValue (content)))
10199
102100 HttpMethod .fromString(request.getMethod) match {
103101 case HttpMethod .POST =>
104- val async = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis))
105- handlePost(call).andThenNow {
102+ val asyncContext = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis))
103+ val completed = new AtomicBoolean (false )
104+ // Need to protect asyncContext from being completed twice because after a timeout the
105+ // servlet may recycle the same context instance between subsequent requests (not cool)
106+ // https://stackoverflow.com/a/27744537
107+ def completeWith (code : => Unit ): Unit =
108+ if (! completed.getAndSet(true )) {
109+ code
110+ asyncContext.complete()
111+ }
112+ Future .fromTry(call).flatMapNow(handlePost).onCompleteNow {
106113 case Success (responseContent) =>
107- response.setContentType(MimeTypes .Type .APPLICATION_JSON .asString())
108- response.setCharacterEncoding(StandardCharsets .UTF_8 .name())
109- response.getWriter.write(responseContent.s)
114+ completeWith {
115+ response.setContentType(MimeTypes .Type .APPLICATION_JSON .asString())
116+ response.setCharacterEncoding(StandardCharsets .UTF_8 .name())
117+ response.getWriter.write(responseContent.s)
118+ }
110119 case Failure (t) =>
111- response.sendError(HttpStatus .INTERNAL_SERVER_ERROR_500 , t.getMessage)
120+ completeWith( response.sendError(HttpStatus .INTERNAL_SERVER_ERROR_500 , t.getMessage) )
112121 logger.error(" Failed to handle RPC call" , t)
113- }.andThenNow { case _ => async.complete() }
122+ }
114123 case HttpMethod .PUT =>
115- handlePut( call)
124+ call.map(handlePut).get
116125 case _ =>
117126 throw new IllegalArgumentException (s " Request HTTP method is ${request.getMethod}, only POST or PUT are supported " )
118127 }
@@ -132,11 +141,12 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
132141 invoke(call)(_.fire)
133142 }
134143
135- def newHandler [T ](impl : T , contextTimeout : FiniteDuration = 30 .seconds)(
136- implicit asRawRPC : AsRawRPC [T ]): Handler =
137- new RPCHandler (asRawRPC.asRaw(impl), contextTimeout)
144+ def newServlet [T : AsRawRPC ](impl : T , contextTimeout : FiniteDuration = 30 .seconds): HttpServlet =
145+ new RPCHandler (AsRawRPC [T ].asRaw(impl), contextTimeout)
146+
147+ def newHandler [T : AsRawRPC ](impl : T , contextTimeout : FiniteDuration = 30 .seconds): Handler =
148+ new ServletContextHandler ().setup(_.addServlet(newServlet(impl, contextTimeout), " /*" ))
138149
139- def newClient [T ](httpClient : HttpClient , uri : String , maxResponseLength : Int = 2 * 1024 * 1024 )(
140- implicit asRealRPC : AsRealRPC [T ]): T =
141- asRealRPC.asReal(new RPCClient (httpClient, uri, maxResponseLength).rawRPC)
150+ def newClient [T : AsRealRPC ](httpClient : HttpClient , uri : String , maxResponseLength : Int = 2 * 1024 * 1024 ): T =
151+ AsRealRPC [T ].asReal(new RPCClient (httpClient, uri, maxResponseLength).rawRPC)
142152}
0 commit comments