Skip to content

Commit

Permalink
codegen: Support enums in paths (softwaremill#3889)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Jul 2, 2024
1 parent c262ea5 commit ae26fc7
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ object BasicGenerator {
JsonSerdeLib.Circe
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
doc = doc,
targetScala3 = targetScala3,
queryParamRefs = queryParamRefs,
queryOrPathParamRefs = queryOrPathParamRefs,
jsonSerdeLib = normalisedJsonLib,
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
Expand Down Expand Up @@ -146,15 +146,18 @@ object BasicGenerator {
"""
|case class CommaSeparatedValues[T](values: List[T])
|case class ExplodedValues[T](values: List[T])
|trait QueryParamSupport[T] {
|trait ExtraParamSupport[T] {
| def decode(s: String): sttp.tapir.DecodeResult[T]
| def encode(t: T): String
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
|implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
|implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
|}
Expand All @@ -169,10 +172,6 @@ object BasicGenerator {
| case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
| }(_.map(_.values.map(support.encode).mkString(",")))
|}
|implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
|}
|implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ClassDefinitionGenerator {
def classDefs(
doc: OpenapiDocument,
targetScala3: Boolean = false,
queryParamRefs: Set[String] = Set.empty,
queryOrPathParamRefs: Set[String] = Set.empty,
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib = JsonSerdeLib.Circe,
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
Expand All @@ -25,10 +25,10 @@ class ClassDefinitionGenerator {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
val adtInheritanceMap: Map[String, Seq[String]] = mkMapParentsByChild(allOneOfSchemas)
val generatesQueryParamEnums = enumsDefinedOnEndpointParams ||
val generatesQueryOrPathParamEnums = enumsDefinedOnEndpointParams ||
allSchemas
.collect { case (name, _: OpenapiSchemaEnum) => name }
.exists(queryParamRefs.contains)
.exists(queryOrPathParamRefs.contains)

def fetchJsonParamRefs(initialSet: Set[String], toCheck: Seq[OpenapiSchemaType]): Set[String] = toCheck match {
case Nil => initialSet
Expand All @@ -41,7 +41,7 @@ class ClassDefinitionGenerator {
)

val adtTypes = adtInheritanceMap.flatMap(_._2).toSeq.distinct.map(name => s"sealed trait $name").mkString("", "\n", "\n")
val enumQuerySerdeHelper = if (!generatesQueryParamEnums) "" else enumQuerySerdeHelperDefn(targetScala3)
val enumSerdeHelper = if (!generatesQueryOrPathParamEnums) "" else enumSerdeHelperDefn(targetScala3)
val schemas = SchemaGenerator.generateSchemas(doc, allSchemas, fullModelPath, jsonSerdeLib, maxSchemasPerFile)
val jsonSerdes = JsonSerdeGenerator.serdeDefs(
doc,
Expand All @@ -58,13 +58,13 @@ class ClassDefinitionGenerator {
case (name, obj: OpenapiSchemaObject) =>
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap, jsonSerdeLib, targetScala3)
case (name, obj: OpenapiSchemaEnum) =>
EnumGenerator.generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
EnumGenerator.generateEnum(name, obj, targetScala3, queryOrPathParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (_, _: OpenapiSchemaOneOf) => Nil
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
})
.map(_.mkString("\n"))
val helpers = (enumQuerySerdeHelper + adtTypes).linesIterator
val helpers = (enumSerdeHelper + adtTypes).linesIterator
.filterNot(_.forall(_.isWhitespace))
.mkString("\n")
// Json serdes & schemas live in separate files from the class defns
Expand Down Expand Up @@ -97,14 +97,14 @@ class ClassDefinitionGenerator {
.groupBy(_._1)
.mapValues(_.map(_._2))

private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = {
private def enumSerdeHelperDefn(targetScala3: Boolean): String = {
if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
Expand All @@ -121,12 +121,12 @@ class ClassDefinitionGenerator {
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|def extraCodecSupport[T: enumextensions.EnumMirror]: ExtraParamSupport[T] =
| EnumExtraParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|""".stripMargin
else
"""
|case class EnumQueryParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
Expand All @@ -142,8 +142,8 @@ class ClassDefinitionGenerator {
| )
| def encode(t: T): String = t.entryName
|}
|def queryCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): QueryParamSupport[T] =
| EnumQueryParamSupport(enumName, T)
|def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
| EnumExtraParamSupport(enumName, T)
|""".stripMargin
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ case class GeneratedEndpoints(
}
case class EndpointDefs(
endpointDecls: Map[Option[String], String],
queryParamRefs: Set[String],
queryOrPathParamRefs: Set[String],
jsonParamRefs: Set[String],
enumsDefinedOnEndpointParams: Boolean
)
Expand All @@ -58,7 +58,7 @@ class EndpointGenerator {
jsonSerdeLib: JsonSerdeLib
): EndpointDefs = {
val components = Option(doc.components).flatten
val GeneratedEndpoints(endpointsByFile, queryParamRefs, jsonParamRefs, definesEnumQueryParam) =
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) =
doc.paths
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib))
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _)
Expand All @@ -77,7 +77,7 @@ class EndpointGenerator {
|$allEP
|""".stripMargin
}.toMap
EndpointDefs(endpointDecls, queryParamRefs, jsonParamRefs, definesEnumQueryParam)
EndpointDefs(endpointDecls, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam)
}

private[codegen] def generatedEndpoints(
Expand Down Expand Up @@ -119,8 +119,8 @@ class EndpointGenerator {
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")

val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
val queryOrPathParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" => queryParam.schema }
.collect { case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped }
.toSet
val jsonParamRefs = (m.requestBody.toSeq.flatMap(_.content.map(c => (c.contentType, c.schema))) ++
Expand All @@ -143,7 +143,7 @@ class EndpointGenerator {
.toSet
(
(maybeTargetFileName, GeneratedEndpoint(name, definition, maybeLocalEnums)),
(queryParamRefs, jsonParamRefs),
(queryOrPathParamRefs, jsonParamRefs),
maybeLocalEnums.isDefined
)
}
Expand Down Expand Up @@ -215,12 +215,12 @@ class EndpointGenerator {
)(implicit location: Location): (String, Option[String]) = {
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
val queryParamRefs = if (param.in == "query") Set(enumName) else Set.empty[String]
val enumParamRefs = if (param.in == "query" || param.in == "path") Set(enumName) else Set.empty[String]
val enumDefn = EnumGenerator.generateEnum(
enumName,
e,
targetScala3,
queryParamRefs,
enumParamRefs,
jsonSerdeLib,
Set.empty
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ object EnumGenerator {
val maybeCompanion =
if (queryParamRefs contains name) {
def helperImpls =
s""" given enumCodecSupport${name.capitalize}: QueryParamSupport[$name] =
| queryCodecSupport[$name]""".stripMargin
s""" given enumCodecSupport${name.capitalize}: ExtraParamSupport[$name] =
| extraCodecSupport[$name]""".stripMargin
s"""
|object $name {
|$helperImpls
Expand Down Expand Up @@ -52,8 +52,8 @@ object EnumGenerator {
val maybeQueryCodecDefn =
if (queryParamRefs contains name) {
s"""
| implicit val enumCodecSupport${name.capitalize}: QueryParamSupport[$name] =
| queryCodecSupport[$name]("${name}", ${name})""".stripMargin
| implicit val enumCodecSupport${name.capitalize}: ExtraParamSupport[$name] =
| extraCodecSupport[$name]("${name}", ${name})""".stripMargin
} else ""
s"""
|sealed trait $name extends enumeratum.EnumEntry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
.classDefs(doc, true, jsonParamRefs = Set("Test"))
.map(concatted)
val resWithQueryParamCodec = gen
.classDefs(doc, true, queryParamRefs = Set("Test"), jsonParamRefs = Set("Test"))
.classDefs(doc, true, queryOrPathParamRefs = Set("Test"), jsonParamRefs = Set("Test"))
.map(concatted)
// can't just check whether these compile, because our tests only run on scala 2.12 - so instead just eyeball it...
res shouldBe Some("""enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
Expand All @@ -304,7 +304,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
Expand All @@ -321,11 +321,11 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|def extraCodecSupport[T: enumextensions.EnumMirror]: ExtraParamSupport[T] =
| EnumExtraParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|object Test {
| given enumCodecSupportTest: QueryParamSupport[Test] =
| queryCodecSupport[Test]
| given enumCodecSupportTest: ExtraParamSupport[Test] =
| extraCodecSupport[Test]
|}
|enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror {
| case enum1, enum2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,15 @@ object TestHelpers {
| $ref: '#/components/schemas/Test'
| post:
| responses: {}
| /pathTest/{test2}:
| parameters:
| - name: test2
| in: path
| required: true
| schema:
| $ref: '#/components/schemas/Test2'
| post:
| responses: {}
|
|components:
| schemas:
Expand All @@ -535,6 +544,12 @@ object TestHelpers {
| enum:
| - paperback
| - hardback
| Test2:
| title: Test
| type: string
| enum:
| - paperback
| - hardback
|""".stripMargin

val enumQueryParamDocs = OpenapiDocument(
Expand All @@ -555,7 +570,24 @@ object TestHelpers {
)
),
parameters = Seq(
Resolved(OpenapiParameter("test", "query", None, None, OpenapiSchemaRef("#/components/schemas/Test")))
Resolved(OpenapiParameter("test", "query", Some(false), None, OpenapiSchemaRef("#/components/schemas/Test")))
)
),
OpenapiPath(
"/pathTest/{test2}",
Seq(
OpenapiPathMethod(
methodType = "post",
parameters = Seq(),
responses = Seq(),
requestBody = None,
summary = None,
tags = None,
operationId = None
)
),
parameters = Seq(
Resolved(OpenapiParameter("test2", "path", Some(true), None, OpenapiSchemaRef("#/components/schemas/Test2")))
)
)
),
Expand All @@ -566,6 +598,11 @@ object TestHelpers {
"string",
Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")),
false
),
"Test2" -> OpenapiSchemaEnum(
"string",
Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")),
false
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ class ModelParserSpec extends AnyFlatSpec with Matchers with Checkers {
res shouldBe Right(
OpenapiSchemaEnum("string", Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")), false)
)
parser
.parse(TestHelpers.enumQueryParamYaml)
.leftMap(err => err: Error)
.flatMap(_.as[OpenapiDocument]) shouldBe Right(
TestHelpers.enumQueryParamDocs
)
}

it should "parse endpoint with defaults" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ object TapirGeneratedEndpoints {

case class CommaSeparatedValues[T](values: List[T])
case class ExplodedValues[T](values: List[T])
trait QueryParamSupport[T] {
trait ExtraParamSupport[T] {
def decode(s: String): sttp.tapir.DecodeResult[T]
def encode(t: T): String
}
implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
}
Expand All @@ -38,10 +41,6 @@ object TapirGeneratedEndpoints {
case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
}(_.map(_.values.map(support.encode).mkString(",")))
}
implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
}
implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}
Expand Down
Loading

0 comments on commit ae26fc7

Please sign in to comment.