Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate servers from tapir-loom #3304

Merged
merged 17 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Import tapir-netty-loom code
  • Loading branch information
kciesielski committed Nov 7, 2023
commit e418a0eb222616e8589776b7d33096cd2d1552d3
22 changes: 21 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ lazy val rawAllAggregates = core.projectRefs ++
vertxServerZio1.projectRefs ++
jdkhttpServer.projectRefs ++
nettyServer.projectRefs ++
nettyServerLoom.projectRefs ++
nettyServerCats.projectRefs ++
nettyServerZio.projectRefs ++
zio1HttpServer.projectRefs ++
Expand Down Expand Up @@ -251,13 +252,21 @@ lazy val rawAllAggregates = core.projectRefs ++
awsCdk.projectRefs

lazy val allAggregates: Seq[ProjectReference] = {
if (sys.env.isDefinedAt("STTP_NATIVE")) {
val filteredByNative = if (sys.env.isDefinedAt("STTP_NATIVE")) {
println("[info] STTP_NATIVE defined, including native in the aggregate projects")
rawAllAggregates
} else {
println("[info] STTP_NATIVE *not* defined, *not* including native in the aggregate projects")
rawAllAggregates.filterNot(_.toString.contains("Native"))
}
if (sys.env.isDefinedAt("JDK_LOOM")) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm maybe this should be called ONLY_LOOM or sth like that? otherwise it sounds additive

println("[info] JDK_LOOM defined, including loom-based projects")
filteredByNative
} else {
println("[info] JDK_LOOM *not* defined, *not* including loom-based-projects")
filteredByNative.filterNot(_.toString.contains("Loom"))
}

}

// separating testing into different Scala versions so that it's not all done at once, as it causes memory problems on CI
Expand Down Expand Up @@ -1443,6 +1452,17 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve
.jvmPlatform(scalaVersions = scala2And3Versions)
.dependsOn(serverCore, serverTests % Test)

lazy val nettyServerLoom: ProjectMatrix =
ProjectMatrix("nettyServerLoom", file(s"server/netty-server/loom"))
.settings(commonJvmSettings)
.settings(
name := "tapir-netty-server-loom",
// needed because of https://github.com/coursier/coursier/issues/2016
useCoursier := false
)
.jvmPlatform(scalaVersions = scala2And3Versions)
.dependsOn(nettyServer, serverTests % Test)

lazy val nettyServerCats: ProjectMatrix = nettyServerProject("cats", catsEffect)
.settings(
libraryDependencies ++= Seq(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package sttp.tapir.server.netty.loom

import io.netty.channel.Channel
import io.netty.channel.EventLoopGroup
import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup}
import io.netty.channel.unix.DomainSocketAddress
import io.netty.util.concurrent.DefaultEventExecutor
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.server.netty.NettyConfig
import sttp.tapir.server.netty.NettyResponse
import sttp.tapir.server.netty.Route
import sttp.tapir.server.netty.internal.NettyBootstrap
import sttp.tapir.server.netty.internal.NettyServerHandler

import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.file.Path
import java.nio.file.Paths
import java.util.UUID
import java.util.concurrent.Executors
import java.util.concurrent.{Future => JFuture}
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) {
private val executor = Executors.newVirtualThreadPerTaskExecutor()

def addEndpoint(se: ServerEndpoint[Any, Id]): NettyIdServer = addEndpoints(List(se))
def addEndpoint(se: ServerEndpoint[Any, Id], overrideOptions: NettyIdServerOptions): NettyIdServer =
addEndpoints(List(se), overrideOptions)
def addEndpoints(ses: List[ServerEndpoint[Any, Id]]): NettyIdServer = addRoute(NettyIdServerInterpreter(options).toRoute(ses))
def addEndpoints(ses: List[ServerEndpoint[Any, Id]], overrideOptions: NettyIdServerOptions): NettyIdServer =
addRoute(NettyIdServerInterpreter(overrideOptions).toRoute(ses))

def addRoute(r: IdRoute): NettyIdServer = copy(routes = routes :+ r)
def addRoutes(r: Iterable[IdRoute]): NettyIdServer = copy(routes = routes ++ r)

def options(o: NettyIdServerOptions): NettyIdServer = copy(options = o)
def config(c: NettyConfig): NettyIdServer = copy(config = c)
def modifyConfig(f: NettyConfig => NettyConfig): NettyIdServer = config(f(config))

def host(hostname: String): NettyIdServer = modifyConfig(_.host(hostname))

def port(p: Int): NettyIdServer = modifyConfig(_.port(p))

def start(): NettyIdServerBinding =
startUsingSocketOverride[InetSocketAddress](None) match {
case (socket, stop) =>
NettyIdServerBinding(socket, stop)
}

def startUsingDomainSocket(path: Option[Path] = None): NettyIdDomainSocketBinding =
startUsingDomainSocket(path.getOrElse(Paths.get(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString)))

def startUsingDomainSocket(path: Path): NettyIdDomainSocketBinding =
startUsingSocketOverride(Some(new DomainSocketAddress(path.toFile))) match {
case (socket, stop) =>
NettyIdDomainSocketBinding(socket, stop)
}

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): (SA, () => Unit) = {
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
val route = Route.combine(routes)

def unsafeRunF(
callToExecute: () => Id[ServerResponse[NettyResponse]]
): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = {
val scalaPromise = Promise[ServerResponse[NettyResponse]]()
val jFuture: JFuture[?] = executor.submit(new Runnable {
override def run(): Unit = try {
val result = callToExecute()
scalaPromise.success(result)
} catch {
case NonFatal(e) => scalaPromise.failure(e)
}
})

(
scalaPromise.future,
() => {
jFuture.cancel(true)
Future.unit
}
)
}
val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe
val isShuttingDown: AtomicBoolean = new AtomicBoolean(false)

val channelIdFuture = NettyBootstrap(
config,
new NettyServerHandler(
route,
unsafeRunF,
config.maxContentLength,
channelGroup,
isShuttingDown
),
eventLoopGroup,
socketOverride
)
channelIdFuture.await()
val channelId = channelIdFuture.channel()

(
channelId.localAddress().asInstanceOf[SA],
() => stop(channelId, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout)
)
}

private def waitForClosedChannels(
channelGroup: ChannelGroup,
startNanos: Long,
gracefulShutdownTimeoutNanos: Option[Long]
): Unit = {
while (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) {
Thread.sleep(100)
}
val _ = channelGroup.close().get()
}
private def stop(
ch: Channel,
eventLoopGroup: EventLoopGroup,
channelGroup: ChannelGroup,
isShuttingDown: AtomicBoolean,
gracefulShutdownTimeout: Option[FiniteDuration]
): Unit = {
isShuttingDown.set(true)
waitForClosedChannels(
channelGroup,
startNanos = System.nanoTime(),
gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos)
)
ch.close().get()
if (config.shutdownEventLoopGroupOnClose) {
val _ = eventLoopGroup.shutdownGracefully().get()
}
}
}

object NettyIdServer {
def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.defaultNoStreaming)

def apply(serverOptions: NettyIdServerOptions): NettyIdServer =
NettyIdServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming)

def apply(config: NettyConfig): NettyIdServer =
NettyIdServer(Vector.empty, NettyIdServerOptions.default, config)

def apply(serverOptions: NettyIdServerOptions, config: NettyConfig): NettyIdServer =
NettyIdServer(Vector.empty, serverOptions, config)
}
case class NettyIdServerBinding(localSocket: InetSocketAddress, stop: () => Unit) {
def hostName: String = localSocket.getHostName
def port: Int = localSocket.getPort
}
case class NettyIdDomainSocketBinding(localSocket: DomainSocketAddress, stop: () => Unit) {
def path: String = localSocket.path()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package sttp.tapir.server.netty.loom

import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync}

trait NettyIdServerInterpreter {
def nettyServerOptions: NettyIdServerOptions

def toRoute(
ses: List[ServerEndpoint[Any, Id]]
): IdRoute = {
NettyServerInterpreter.toRoute[Id](
ses,
nettyServerOptions.interceptors,
nettyServerOptions.createFile,
nettyServerOptions.deleteFile,
new RunAsync[Id] {
override def apply[T](f: => Id[T]): Unit = {
val _ = f
()
}
}
)
}
}

object NettyIdServerInterpreter {
def apply(serverOptions: NettyIdServerOptions = NettyIdServerOptions.default): NettyIdServerInterpreter = {
new NettyIdServerInterpreter {
override def nettyServerOptions: NettyIdServerOptions = serverOptions
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package sttp.tapir.server.netty.loom

import com.typesafe.scalalogging.Logger
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interceptor.log.{DefaultServerLog, ServerLog}
import sttp.tapir.server.netty.internal.NettyDefaults
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor}
import sttp.tapir.{Defaults, TapirFile}

case class NettyIdServerOptions(
interceptors: List[Interceptor[Id]],
createFile: ServerRequest => TapirFile,
deleteFile: TapirFile => Unit
) {
def prependInterceptor(i: Interceptor[Id]): NettyIdServerOptions = copy(interceptors = i :: interceptors)
def appendInterceptor(i: Interceptor[Id]): NettyIdServerOptions = copy(interceptors = interceptors :+ i)
}

object NettyIdServerOptions {

/** Default options, using TCP sockets (the most common case). This can be later customised using [[NettyIdServerOptions#nettyOptions()]].
*/
def default: NettyIdServerOptions = customiseInterceptors.options

private def default(
interceptors: List[Interceptor[Id]]
): NettyIdServerOptions =
NettyIdServerOptions(
interceptors,
_ => Defaults.createTempFile(),
Defaults.deleteFile()
)

/** Customise the interceptors that are being used when exposing endpoints as a server. By default uses TCP sockets (the most common
* case), but this can be later customised using [[NettyIdServerOptions#nettyOptions()]].
*/
def customiseInterceptors: CustomiseInterceptors[Id, NettyIdServerOptions] = {
CustomiseInterceptors(
createOptions = (ci: CustomiseInterceptors[Id, NettyIdServerOptions]) => default(ci.interceptors)
).serverLog(defaultServerLog)
}

private val log = Logger[NettyIdServerInterpreter]

lazy val defaultServerLog: ServerLog[Id] = {
DefaultServerLog[Id](
doLogWhenReceived = debugLog(_, None),
doLogWhenHandled = debugLog,
doLogAllDecodeFailures = debugLog,
doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex),
noLog = ()
)
}

private def debugLog(msg: String, exOpt: Option[Throwable]): Unit = NettyDefaults.debugLog(log, msg, exOpt)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package sttp.tapir.server.netty

import sttp.monad.MonadError

package object loom {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this should be an id package?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or ... maybe we can put the whole Id interpreter into the main package? only problem, we'd need to compile it using JDK 21. But maybe we can compile using JDK 21 and target 11 bytecode?

Then we could also share the Id alias across netty-loom / nima

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but maybe it's better to keep this separate, I don't know ;)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah let's not overcomplicate for now ;)

type Id[X] = X
type IdRoute = Route[Id]

private[loom] implicit val idMonad: MonadError[Id] = new MonadError[Id] {
override def unit[T](t: T): Id[T] = t
override def map[T, T2](fa: Id[T])(f: T => T2): Id[T2] = f(fa)
override def flatMap[T, T2](fa: Id[T])(f: T => Id[T2]): Id[T2] = f(fa)
override def error[T](t: Throwable): Id[T] = throw t
override protected def handleWrappedError[T](rt: Id[T])(h: PartialFunction[Throwable, Id[T]]): Id[T] = rt
override def eval[T](t: => T): Id[T] = t
override def ensure[T](f: Id[T], e: => Id[Unit]): Id[T] =
try f
finally e
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package sttp.tapir.server.netty.loom

import cats.effect.{IO, Resource}
import io.netty.channel.nio.NioEventLoopGroup
import org.scalatest.EitherValues
import sttp.tapir.server.netty.internal.FutureUtil.nettyFutureToScala
import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future

class NettyIdServerTest extends TestSuite with EitherValues {
override def tests: Resource[IO, List[Test]] =
backendResource.flatMap { backend =>
Resource
.make(IO.delay {
val eventLoopGroup = new NioEventLoopGroup()

val interpreter = new NettyIdTestServerInterpreter(eventLoopGroup)
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests()

(tests, eventLoopGroup)
}) { case (_, eventLoopGroup) =>
IO.fromFuture(IO.delay(nettyFutureToScala(eventLoopGroup.shutdownGracefully()): Future[_])).void
}
.map { case (tests, _) => tests }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package sttp.tapir.server.netty.loom

import cats.data.NonEmptyList
import cats.effect.{IO, Resource}
import io.netty.channel.nio.NioEventLoopGroup
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.NettyConfig
import sttp.tapir.server.tests.TestServerInterpreter
import sttp.tapir.tests.Port

class NettyIdTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)
extends TestServerInterpreter[Id, Any, NettyIdServerOptions, IdRoute] {
override def route(es: List[ServerEndpoint[Any, Id]], interceptors: Interceptors): IdRoute = {
val serverOptions: NettyIdServerOptions = interceptors(NettyIdServerOptions.customiseInterceptors).options
NettyIdServerInterpreter(serverOptions).toRoute(es)
}

override def server(routes: NonEmptyList[IdRoute]): Resource[IO, Port] = {
val config =
NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown
val options = NettyIdServerOptions.default
val bind = IO.blocking(NettyIdServer(options, config).addRoutes(routes.toList).start())

Resource
.make(bind)(binding => IO.blocking(binding.stop()))
.map(b => b.port)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sttp.tapir.server.netty.loom

import sttp.tapir._

object SleepDemo extends App {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be removed now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or ... move the "demos" to examples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to add anything more than it's in the docs, so I'll just remove it.

val e = endpoint.get.in("hello").out(stringBody).serverLogicSuccess[Id] { _ =>
Thread.sleep(1000)
"Hello, world!"
}
NettyIdServer().addEndpoint(e).start()
}