Skip to content

Commit

Permalink
alternative approach to param resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Jul 27, 2023
1 parent 6cf8ede commit 789a71d
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class EndpointGenerator {
private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument): String = {
val ge = doc.paths.flatMap(generatedEndpoints)
val ps = Option(doc.components).flatten.map(_.parameters) getOrElse Map.empty
val ge = doc.paths.flatMap(generatedEndpoints(ps))
val definitions = ge
.map { case (name, definition) =>
s"""|val $name =
Expand All @@ -26,13 +27,13 @@ class EndpointGenerator {
|""".stripMargin
}

private[codegen] def generatedEndpoints(p: OpenapiPath): Seq[(String, String)] = {
p.methods.map(_ withParentParameters p.parameters).map { m =>
private[codegen] def generatedEndpoints(parameters: Map[String, OpenapiParameter])(p: OpenapiPath): Seq[(String, String)] = {
p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m =>
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.parameters)}
|${indent(2)(ins(m.parameters, m.requestBody))}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin
Expand All @@ -54,7 +55,7 @@ class EndpointGenerator {
if (segment.startsWith("{")) {
val name = segment.drop(1).dropRight(1)
val param = parameters.find(_.name == name)
param.fold(throw new Error("URLParam not found!")) { p =>
param.fold(throw new Error(s"URLParam $name not found!")) { p =>
p.schema match {
case st: OpenapiSchemaSimpleType =>
val (t, _) = mapSchemaSimpleTypeToType(st)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package sttp.tapir.codegen.openapi.models

case class OpenapiComponent(schemas: Map[String, OpenapiSchemaType])
import cats.syntax.either._

import OpenapiModels.OpenapiParameter

case class OpenapiComponent(
schemas: Map[String, OpenapiSchemaType],
parameters: Map[String, OpenapiParameter] = Map.empty
)

object OpenapiComponent {
import io.circe._

implicit val OpenapiComponentDecoder: Decoder[OpenapiComponent] = { (c: HCursor) =>
for {
schemas <- c.downField("schemas").as[Map[String, OpenapiSchemaType]]
parameters <- c.downField("parameters").as[Map[String, OpenapiParameter]].orElse(Right(Map.empty[String, OpenapiParameter]))
} yield {
OpenapiComponent(schemas)
OpenapiComponent(schemas, parameters.map { case (k, v) => s"#/components/parameters/$k" -> v })
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,21 @@ package sttp.tapir.codegen.openapi.models
import cats.implicits.toTraverseOps
import cats.syntax.either._

import OpenapiSchemaType.OpenapiSchemaRef
// https://swagger.io/specification/
object OpenapiModels {

sealed trait Resolvable[T] {
def resolve(input: Map[String, T]): T
def toResolved(input: Map[String, T]): Resolved[T] = Resolved(resolve(input))
}
case class Resolved[T](t: T) extends Resolvable[T] {
override def resolve(input: Map[String, T]): T = t
}
case class Ref[T](name: String) extends Resolvable[T] {
override def resolve(input: Map[String, T]): T = input.getOrElse(name, throw new IllegalArgumentException(s"Cannot resolve $name"))
}

case class OpenapiDocument(
openapi: String,
// not used so not parsed; servers, contact, license, termsOfService
Expand All @@ -23,20 +35,32 @@ object OpenapiModels {
case class OpenapiPath(
url: String,
methods: Seq[OpenapiPathMethod],
parameters: Seq[OpenapiParameter] = Nil
parameters: Seq[Resolvable[OpenapiParameter]] = Nil
)

case class OpenapiPathMethod(
methodType: String,
parameters: Seq[OpenapiParameter],
parameters: Seq[Resolvable[OpenapiParameter]],
responses: Seq[OpenapiResponse],
requestBody: Option[OpenapiRequestBody],
summary: Option[String] = None,
tags: Option[Seq[String]] = None,
operationId: Option[String] = None
) {
def withParentParameters(pathParameters: Seq[OpenapiParameter]): OpenapiPathMethod =
this.copy(parameters = pathParameters.filterNot(p => parameters.exists(p.name == _.name)) ++ parameters)
def resolvedParameters: Seq[OpenapiParameter] = parameters.collect { case Resolved(t) => t }
def withResolvedParentParameters(
pMap: Map[String, OpenapiParameter],
pathParameters: Seq[Resolvable[OpenapiParameter]]
): OpenapiPathMethod = {
val resolved = parameters.map(_.toResolved(pMap))
val duplicates = resolved.groupBy(_.t.name).filter(_._2.size > 1).keys
if (duplicates.nonEmpty) throw new IllegalArgumentException(s"Duplicate parameters ${duplicates.mkString(", ")}")
val filteredParents: Seq[Resolved[OpenapiParameter]] =
pathParameters.map(_.toResolved(pMap)).filterNot(p => resolved.exists(p.t.name == _.t.name))
val parentDuplicates = filteredParents.groupBy(_.t.name).filter(_._2.size > 1).keys
if (parentDuplicates.nonEmpty) throw new IllegalArgumentException(s"Duplicate parameters ${parentDuplicates.mkString(", ")}")
this.copy(parameters = filteredParents ++ resolved)
}
}

case class OpenapiParameter(
Expand Down Expand Up @@ -138,9 +162,15 @@ object OpenapiModels {

implicit val OpenapiInfoDecoder: Decoder[OpenapiInfo] = deriveDecoder[OpenapiInfo]
implicit val OpenapiParameterDecoder: Decoder[OpenapiParameter] = deriveDecoder[OpenapiParameter]
implicit def ResolvableDecoder[T: Decoder]: Decoder[Resolvable[T]] = { (c: HCursor) =>
c.as[T].map(Resolved(_)).orElse(c.as[OpenapiSchemaRef].map(r => Ref(r.name)))
}
implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) =>
for {
parameters <- c.downField("parameters").as[Seq[OpenapiParameter]].orElse(Right(List.empty[OpenapiParameter]))
parameters <- c
.downField("parameters")
.as[Seq[Resolvable[OpenapiParameter]]]
.orElse(Right(List.empty[Resolvable[OpenapiParameter]]))
responses <- c.downField("responses").as[Seq[OpenapiResponse]]
requestBody <- c.downField("requestBody").as[Option[OpenapiRequestBody]]
summary <- c.downField("summary").as[Option[String]]
Expand All @@ -153,7 +183,7 @@ object OpenapiModels {

implicit val PartialOpenapiPathDecoder: Decoder[OpenapiPath] = { (c: HCursor) =>
for {
parameters <- c.downField("parameters").as[Seq[OpenapiParameter]].orElse(Right(List.empty[OpenapiParameter]))
parameters <- c.downField("parameters").as[Seq[Resolvable[OpenapiParameter]]].orElse(Right(List.empty[Resolvable[OpenapiParameter]]))
methods <- List("get", "put", "post", "delete", "options", "head", "patch", "patch", "connect")
.traverse(method => c.downField(method).as[Option[OpenapiPathMethod]].map(_.map(_.copy(methodType = method))))
} yield OpenapiPath("--partial--", methods.flatten, parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiPath,
OpenapiPathMethod,
OpenapiResponse,
OpenapiResponseContent
OpenapiResponseContent,
Resolved
}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaArray, OpenapiSchemaString}
import sttp.tapir.codegen.testutils.CompileCheckTestBase
Expand All @@ -23,7 +24,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
Seq(
OpenapiPathMethod(
methodType = "get",
parameters = Seq(OpenapiParameter("asd-id", "path", true, None, OpenapiSchemaString(false))),
parameters = Seq(Resolved(OpenapiParameter("asd-id", "path", true, None, OpenapiSchemaString(false)))),
responses = Seq(
OpenapiResponse(
"200",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiRequestBody,
OpenapiRequestBodyContent,
OpenapiResponse,
OpenapiResponseContent
OpenapiResponseContent,
Ref,
Resolved
}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
Expand All @@ -37,6 +39,7 @@ object TestHelpers {
| required: true
| schema:
| type: string
| - $ref: '#/components/parameters/year'
| post:
| operationId: postBooksGenreYear
| parameters:
Expand Down Expand Up @@ -82,11 +85,7 @@ object TestHelpers {
| required: true
| schema:
| type: string
| - name: year
| in: path
| required: true
| schema:
| type: integer
| - $ref: '#/components/parameters/offset'
| - name: limit
| in: query
| description: Maximum number of books to retrieve
Expand Down Expand Up @@ -124,6 +123,20 @@ object TestHelpers {
| properties:
| title:
| type: string
| parameters:
| offset:
| name: offset
| in: query
| description: Offset at which to start fetching books
| required: true
| schema:
| type: integer
| year:
| name: year
| in: path
| required: true
| schema:
| type: integer
|""".stripMargin

val myBookshopDoc = OpenapiDocument(
Expand All @@ -136,10 +149,10 @@ object TestHelpers {
OpenapiPathMethod(
methodType = "get",
parameters = Seq(
OpenapiParameter("genre", "path", true, None, OpenapiSchemaString(false)),
OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false)),
OpenapiParameter("limit", "query", true, Some("Maximum number of books to retrieve"), OpenapiSchemaInt(false)),
OpenapiParameter("X-Auth-Token", "header", true, None, OpenapiSchemaString(false))
Resolved(OpenapiParameter("genre", "path", true, None, OpenapiSchemaString(false))),
Ref[OpenapiParameter]("#/components/parameters/offset"),
Resolved(OpenapiParameter("limit", "query", true, Some("Maximum number of books to retrieve"), OpenapiSchemaInt(false))),
Resolved(OpenapiParameter("X-Auth-Token", "header", true, None, OpenapiSchemaString(false)))
),
responses = Seq(
OpenapiResponse(
Expand All @@ -157,9 +170,9 @@ object TestHelpers {
OpenapiPathMethod(
methodType = "post",
parameters = Seq(
OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false)),
OpenapiParameter("limit", "query", true, Some("Maximum number of books to retrieve"), OpenapiSchemaInt(false)),
OpenapiParameter("X-Auth-Token", "header", true, None, OpenapiSchemaString(false))
Resolved(OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false))),
Resolved(OpenapiParameter("limit", "query", true, Some("Maximum number of books to retrieve"), OpenapiSchemaInt(false))),
Resolved(OpenapiParameter("X-Auth-Token", "header", true, None, OpenapiSchemaString(false)))
),
responses = Seq(
OpenapiResponse(
Expand All @@ -185,13 +198,21 @@ object TestHelpers {
operationId = Some("postBooksGenreYear")
)
),
parameters = Seq(OpenapiParameter("genre", "path", true, None, OpenapiSchemaString(false)))
parameters = Seq(
Resolved(OpenapiParameter("genre", "path", true, None, OpenapiSchemaString(false))),
Ref("#/components/parameters/year")
)
)
),
Some(
OpenapiComponent(
Map(
"Book" -> OpenapiSchemaObject(Map("title" -> OpenapiSchemaString(false)), Seq("title"), false)
),
Map(
"#/components/parameters/offset" ->
OpenapiParameter("offset", "query", true, Some("Offset at which to start fetching books"), OpenapiSchemaInt(false)),
"#/components/parameters/year" -> OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false))
)
)
)
Expand Down Expand Up @@ -273,7 +294,7 @@ object TestHelpers {
OpenapiPathMethod(
methodType = "get",
parameters = Seq(
OpenapiParameter("name", "query", true, None, OpenapiSchemaString(false))
Resolved(OpenapiParameter("name", "query", true, None, OpenapiSchemaString(false)))
),
responses = Seq(
OpenapiResponse(
Expand Down Expand Up @@ -374,7 +395,7 @@ object TestHelpers {
OpenapiPathMethod(
methodType = "get",
Seq(
OpenapiParameter("name", "path", true, None, OpenapiSchemaString(false))
Resolved(OpenapiParameter("name", "path", true, None, OpenapiSchemaString(false)))
),
responses = Seq(
OpenapiResponse(
Expand Down

0 comments on commit 789a71d

Please sign in to comment.