-
Notifications
You must be signed in to change notification settings - Fork 409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate servers from tapir-loom #3304
Changes from 1 commit
e418a0e
76dfa35
25b4106
9a03a46
49af16b
1492132
a52779a
dcf8d22
8fe4593
a3f6643
ee87d44
e5ab8f2
10c944c
d33aff8
87235c3
5273d36
bc70301
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import io.netty.channel.Channel | ||
import io.netty.channel.EventLoopGroup | ||
import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} | ||
import io.netty.channel.unix.DomainSocketAddress | ||
import io.netty.util.concurrent.DefaultEventExecutor | ||
import sttp.tapir.server.ServerEndpoint | ||
import sttp.tapir.server.model.ServerResponse | ||
import sttp.tapir.server.netty.NettyConfig | ||
import sttp.tapir.server.netty.NettyResponse | ||
import sttp.tapir.server.netty.Route | ||
import sttp.tapir.server.netty.internal.NettyBootstrap | ||
import sttp.tapir.server.netty.internal.NettyServerHandler | ||
|
||
import java.net.InetSocketAddress | ||
import java.net.SocketAddress | ||
import java.nio.file.Path | ||
import java.nio.file.Paths | ||
import java.util.UUID | ||
import java.util.concurrent.Executors | ||
import java.util.concurrent.{Future => JFuture} | ||
import java.util.concurrent.atomic.AtomicBoolean | ||
import scala.concurrent.Future | ||
import scala.concurrent.Promise | ||
import scala.concurrent.duration.FiniteDuration | ||
import scala.util.control.NonFatal | ||
|
||
case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) { | ||
private val executor = Executors.newVirtualThreadPerTaskExecutor() | ||
|
||
def addEndpoint(se: ServerEndpoint[Any, Id]): NettyIdServer = addEndpoints(List(se)) | ||
def addEndpoint(se: ServerEndpoint[Any, Id], overrideOptions: NettyIdServerOptions): NettyIdServer = | ||
addEndpoints(List(se), overrideOptions) | ||
def addEndpoints(ses: List[ServerEndpoint[Any, Id]]): NettyIdServer = addRoute(NettyIdServerInterpreter(options).toRoute(ses)) | ||
def addEndpoints(ses: List[ServerEndpoint[Any, Id]], overrideOptions: NettyIdServerOptions): NettyIdServer = | ||
addRoute(NettyIdServerInterpreter(overrideOptions).toRoute(ses)) | ||
|
||
def addRoute(r: IdRoute): NettyIdServer = copy(routes = routes :+ r) | ||
def addRoutes(r: Iterable[IdRoute]): NettyIdServer = copy(routes = routes ++ r) | ||
|
||
def options(o: NettyIdServerOptions): NettyIdServer = copy(options = o) | ||
def config(c: NettyConfig): NettyIdServer = copy(config = c) | ||
def modifyConfig(f: NettyConfig => NettyConfig): NettyIdServer = config(f(config)) | ||
|
||
def host(hostname: String): NettyIdServer = modifyConfig(_.host(hostname)) | ||
|
||
def port(p: Int): NettyIdServer = modifyConfig(_.port(p)) | ||
|
||
def start(): NettyIdServerBinding = | ||
startUsingSocketOverride[InetSocketAddress](None) match { | ||
case (socket, stop) => | ||
NettyIdServerBinding(socket, stop) | ||
} | ||
|
||
def startUsingDomainSocket(path: Option[Path] = None): NettyIdDomainSocketBinding = | ||
startUsingDomainSocket(path.getOrElse(Paths.get(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString))) | ||
|
||
def startUsingDomainSocket(path: Path): NettyIdDomainSocketBinding = | ||
startUsingSocketOverride(Some(new DomainSocketAddress(path.toFile))) match { | ||
case (socket, stop) => | ||
NettyIdDomainSocketBinding(socket, stop) | ||
} | ||
|
||
private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): (SA, () => Unit) = { | ||
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() | ||
val route = Route.combine(routes) | ||
|
||
def unsafeRunF( | ||
callToExecute: () => Id[ServerResponse[NettyResponse]] | ||
): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { | ||
val scalaPromise = Promise[ServerResponse[NettyResponse]]() | ||
val jFuture: JFuture[?] = executor.submit(new Runnable { | ||
override def run(): Unit = try { | ||
val result = callToExecute() | ||
scalaPromise.success(result) | ||
} catch { | ||
case NonFatal(e) => scalaPromise.failure(e) | ||
} | ||
}) | ||
|
||
( | ||
scalaPromise.future, | ||
() => { | ||
jFuture.cancel(true) | ||
Future.unit | ||
} | ||
) | ||
} | ||
val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe | ||
val isShuttingDown: AtomicBoolean = new AtomicBoolean(false) | ||
|
||
val channelIdFuture = NettyBootstrap( | ||
config, | ||
new NettyServerHandler( | ||
route, | ||
unsafeRunF, | ||
config.maxContentLength, | ||
channelGroup, | ||
isShuttingDown | ||
), | ||
eventLoopGroup, | ||
socketOverride | ||
) | ||
channelIdFuture.await() | ||
val channelId = channelIdFuture.channel() | ||
|
||
( | ||
channelId.localAddress().asInstanceOf[SA], | ||
() => stop(channelId, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout) | ||
) | ||
} | ||
|
||
private def waitForClosedChannels( | ||
channelGroup: ChannelGroup, | ||
startNanos: Long, | ||
gracefulShutdownTimeoutNanos: Option[Long] | ||
): Unit = { | ||
while (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) { | ||
Thread.sleep(100) | ||
} | ||
val _ = channelGroup.close().get() | ||
} | ||
private def stop( | ||
ch: Channel, | ||
eventLoopGroup: EventLoopGroup, | ||
channelGroup: ChannelGroup, | ||
isShuttingDown: AtomicBoolean, | ||
gracefulShutdownTimeout: Option[FiniteDuration] | ||
): Unit = { | ||
isShuttingDown.set(true) | ||
waitForClosedChannels( | ||
channelGroup, | ||
startNanos = System.nanoTime(), | ||
gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) | ||
) | ||
ch.close().get() | ||
if (config.shutdownEventLoopGroupOnClose) { | ||
val _ = eventLoopGroup.shutdownGracefully().get() | ||
} | ||
} | ||
} | ||
|
||
object NettyIdServer { | ||
def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.defaultNoStreaming) | ||
|
||
def apply(serverOptions: NettyIdServerOptions): NettyIdServer = | ||
NettyIdServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) | ||
|
||
def apply(config: NettyConfig): NettyIdServer = | ||
NettyIdServer(Vector.empty, NettyIdServerOptions.default, config) | ||
|
||
def apply(serverOptions: NettyIdServerOptions, config: NettyConfig): NettyIdServer = | ||
NettyIdServer(Vector.empty, serverOptions, config) | ||
} | ||
case class NettyIdServerBinding(localSocket: InetSocketAddress, stop: () => Unit) { | ||
def hostName: String = localSocket.getHostName | ||
def port: Int = localSocket.getPort | ||
} | ||
case class NettyIdDomainSocketBinding(localSocket: DomainSocketAddress, stop: () => Unit) { | ||
def path: String = localSocket.path() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import sttp.tapir.server.ServerEndpoint | ||
import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} | ||
|
||
trait NettyIdServerInterpreter { | ||
def nettyServerOptions: NettyIdServerOptions | ||
|
||
def toRoute( | ||
ses: List[ServerEndpoint[Any, Id]] | ||
): IdRoute = { | ||
NettyServerInterpreter.toRoute[Id]( | ||
ses, | ||
nettyServerOptions.interceptors, | ||
nettyServerOptions.createFile, | ||
nettyServerOptions.deleteFile, | ||
new RunAsync[Id] { | ||
override def apply[T](f: => Id[T]): Unit = { | ||
val _ = f | ||
() | ||
} | ||
} | ||
) | ||
} | ||
} | ||
|
||
object NettyIdServerInterpreter { | ||
def apply(serverOptions: NettyIdServerOptions = NettyIdServerOptions.default): NettyIdServerInterpreter = { | ||
new NettyIdServerInterpreter { | ||
override def nettyServerOptions: NettyIdServerOptions = serverOptions | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import com.typesafe.scalalogging.Logger | ||
import sttp.tapir.model.ServerRequest | ||
import sttp.tapir.server.interceptor.log.{DefaultServerLog, ServerLog} | ||
import sttp.tapir.server.netty.internal.NettyDefaults | ||
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor} | ||
import sttp.tapir.{Defaults, TapirFile} | ||
|
||
case class NettyIdServerOptions( | ||
interceptors: List[Interceptor[Id]], | ||
createFile: ServerRequest => TapirFile, | ||
deleteFile: TapirFile => Unit | ||
) { | ||
def prependInterceptor(i: Interceptor[Id]): NettyIdServerOptions = copy(interceptors = i :: interceptors) | ||
def appendInterceptor(i: Interceptor[Id]): NettyIdServerOptions = copy(interceptors = interceptors :+ i) | ||
} | ||
|
||
object NettyIdServerOptions { | ||
|
||
/** Default options, using TCP sockets (the most common case). This can be later customised using [[NettyIdServerOptions#nettyOptions()]]. | ||
*/ | ||
def default: NettyIdServerOptions = customiseInterceptors.options | ||
|
||
private def default( | ||
interceptors: List[Interceptor[Id]] | ||
): NettyIdServerOptions = | ||
NettyIdServerOptions( | ||
interceptors, | ||
_ => Defaults.createTempFile(), | ||
Defaults.deleteFile() | ||
) | ||
|
||
/** Customise the interceptors that are being used when exposing endpoints as a server. By default uses TCP sockets (the most common | ||
* case), but this can be later customised using [[NettyIdServerOptions#nettyOptions()]]. | ||
*/ | ||
def customiseInterceptors: CustomiseInterceptors[Id, NettyIdServerOptions] = { | ||
CustomiseInterceptors( | ||
createOptions = (ci: CustomiseInterceptors[Id, NettyIdServerOptions]) => default(ci.interceptors) | ||
).serverLog(defaultServerLog) | ||
} | ||
|
||
private val log = Logger[NettyIdServerInterpreter] | ||
|
||
lazy val defaultServerLog: ServerLog[Id] = { | ||
DefaultServerLog[Id]( | ||
doLogWhenReceived = debugLog(_, None), | ||
doLogWhenHandled = debugLog, | ||
doLogAllDecodeFailures = debugLog, | ||
doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex), | ||
noLog = () | ||
) | ||
} | ||
|
||
private def debugLog(msg: String, exOpt: Option[Throwable]): Unit = NettyDefaults.debugLog(log, msg, exOpt) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package sttp.tapir.server.netty | ||
|
||
import sttp.monad.MonadError | ||
|
||
package object loom { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this should be an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or ... maybe we can put the whole Then we could also share the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but maybe it's better to keep this separate, I don't know ;) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah let's not overcomplicate for now ;) |
||
type Id[X] = X | ||
type IdRoute = Route[Id] | ||
|
||
private[loom] implicit val idMonad: MonadError[Id] = new MonadError[Id] { | ||
override def unit[T](t: T): Id[T] = t | ||
override def map[T, T2](fa: Id[T])(f: T => T2): Id[T2] = f(fa) | ||
override def flatMap[T, T2](fa: Id[T])(f: T => Id[T2]): Id[T2] = f(fa) | ||
override def error[T](t: Throwable): Id[T] = throw t | ||
override protected def handleWrappedError[T](rt: Id[T])(h: PartialFunction[Throwable, Id[T]]): Id[T] = rt | ||
override def eval[T](t: => T): Id[T] = t | ||
override def ensure[T](f: Id[T], e: => Id[Unit]): Id[T] = | ||
try f | ||
finally e | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import cats.effect.{IO, Resource} | ||
import io.netty.channel.nio.NioEventLoopGroup | ||
import org.scalatest.EitherValues | ||
import sttp.tapir.server.netty.internal.FutureUtil.nettyFutureToScala | ||
import sttp.tapir.server.tests._ | ||
import sttp.tapir.tests.{Test, TestSuite} | ||
|
||
import scala.concurrent.Future | ||
|
||
class NettyIdServerTest extends TestSuite with EitherValues { | ||
override def tests: Resource[IO, List[Test]] = | ||
backendResource.flatMap { backend => | ||
Resource | ||
.make(IO.delay { | ||
val eventLoopGroup = new NioEventLoopGroup() | ||
|
||
val interpreter = new NettyIdTestServerInterpreter(eventLoopGroup) | ||
val createServerTest = new DefaultCreateServerTest(backend, interpreter) | ||
|
||
val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() | ||
|
||
(tests, eventLoopGroup) | ||
}) { case (_, eventLoopGroup) => | ||
IO.fromFuture(IO.delay(nettyFutureToScala(eventLoopGroup.shutdownGracefully()): Future[_])).void | ||
} | ||
.map { case (tests, _) => tests } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import cats.data.NonEmptyList | ||
import cats.effect.{IO, Resource} | ||
import io.netty.channel.nio.NioEventLoopGroup | ||
import sttp.tapir.server.ServerEndpoint | ||
import sttp.tapir.server.netty.NettyConfig | ||
import sttp.tapir.server.tests.TestServerInterpreter | ||
import sttp.tapir.tests.Port | ||
|
||
class NettyIdTestServerInterpreter(eventLoopGroup: NioEventLoopGroup) | ||
extends TestServerInterpreter[Id, Any, NettyIdServerOptions, IdRoute] { | ||
override def route(es: List[ServerEndpoint[Any, Id]], interceptors: Interceptors): IdRoute = { | ||
val serverOptions: NettyIdServerOptions = interceptors(NettyIdServerOptions.customiseInterceptors).options | ||
NettyIdServerInterpreter(serverOptions).toRoute(es) | ||
} | ||
|
||
override def server(routes: NonEmptyList[IdRoute]): Resource[IO, Port] = { | ||
val config = | ||
NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown | ||
val options = NettyIdServerOptions.default | ||
val bind = IO.blocking(NettyIdServer(options, config).addRoutes(routes.toList).start()) | ||
|
||
Resource | ||
.make(bind)(binding => IO.blocking(binding.stop())) | ||
.map(b => b.port) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package sttp.tapir.server.netty.loom | ||
|
||
import sttp.tapir._ | ||
|
||
object SleepDemo extends App { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be removed now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or ... move the "demos" to examples? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't seem to add anything more than it's in the docs, so I'll just remove it. |
||
val e = endpoint.get.in("hello").out(stringBody).serverLogicSuccess[Id] { _ => | ||
Thread.sleep(1000) | ||
"Hello, world!" | ||
} | ||
NettyIdServer().addEndpoint(e).start() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm maybe this should be called
ONLY_LOOM
or sth like that? otherwise it sounds additive