Skip to content

Commit

Permalink
Some codegen improvements (softwaremill#3090)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Aug 7, 2023
1 parent dd57469 commit 425d65c
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object BasicGenerator {
("Boolean", nb)
case OpenapiSchemaRef(t) =>
(t.split('/').last, false)
case _ => throw new NotImplementedError("Not all simple types supported!")
case x => throw new NotImplementedError(s"Not all simple types supported! Found $x")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,20 @@ class ClassDefinitionGenerator {
generateClass(name, obj)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3)
case _ => throw new NotImplementedError("Only objects and enums supported!")
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
})
.map(_.mkString("\n"))
}

private[codegen] def generateMap(name: String, valueSchema: OpenapiSchemaType): Seq[String] = {
val valueSchemaName = valueSchema match {
case simpleType: OpenapiSchemaSimpleType => BasicGenerator.mapSchemaSimpleTypeToType(simpleType)._1
case otherType => throw new NotImplementedError(s"Only simple value types and refs are implemented for named maps (found $otherType)")
}
Seq(s"""type $name = Map[String, $valueSchemaName]""")
}

// Uses enumeratum for scala 2, but generates scala 3 enums instead where it can
private[codegen] def generateEnum(name: String, obj: OpenapiSchemaEnum, targetScala3: Boolean): Seq[String] = if (targetScala3) {
s"""enum $name {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class EndpointGenerator {
val ge = doc.paths.flatMap(generatedEndpoints(ps))
val definitions = ge
.map { case (name, definition) =>
s"""|val $name =
s"""|lazy val $name =
|${indent(2)(definition)}
|""".stripMargin
}
.mkString("\n")
val allEP = s"val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
|
Expand Down Expand Up @@ -82,11 +82,11 @@ class EndpointGenerator {
val (t, _) = mapSchemaSimpleTypeToType(st)
val desc = param.description.fold("")(d => s""".description("$d")""")
s""".in(${param.in}[$t]("${param.name}")$desc)"""
case _ => throw new NotImplementedError("Can't create non-simple params to input")
case x => throw new NotImplementedError(s"Can't create non-simple params to input - found $x")
}
}

val rqBody = requestBody.flatMap{ b =>
val rqBody = requestBody.flatMap { b =>
if (b.content.isEmpty) None
else if (b.content.size != 1) throw new NotImplementedError("We can handle only one requestBody content!")
else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, b.required)})")
Expand Down Expand Up @@ -136,11 +136,11 @@ class EndpointGenerator {
case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"List[$t]"
case _ => throw new NotImplementedError("Can't create non-simple or array params as output")
case x => throw new NotImplementedError(s"Can't create non-simple or array params as output (found $x)")
}
val req = if (required) outT else s"Option[$outT]"
s"jsonBody[$req]"
case _ => throw new NotImplementedError("We only handle json and text!")
case x => throw new NotImplementedError(s"We only handle json and text! Found $x")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object OpenapiComponent {
implicit val OpenapiComponentDecoder: Decoder[OpenapiComponent] = { (c: HCursor) =>
for {
schemas <- c.downField("schemas").as[Map[String, OpenapiSchemaType]]
parameters <- c.downField("parameters").as[Map[String, OpenapiParameter]].orElse(Right(Map.empty[String, OpenapiParameter]))
parameters <- c.downField("parameters").as[Option[Map[String, OpenapiParameter]]].map(_.getOrElse(Map.empty))
} yield {
OpenapiComponent(schemas, parameters.map { case (k, v) => s"#/components/parameters/$k" -> v })
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ object OpenapiModels {
for {
parameters <- c
.downField("parameters")
.as[Seq[Resolvable[OpenapiParameter]]]
.orElse(Right(List.empty[Resolvable[OpenapiParameter]]))
.as[Option[Seq[Resolvable[OpenapiParameter]]]]
.map(_.getOrElse(Nil))
responses <- c.downField("responses").as[Seq[OpenapiResponse]]
requestBody <- c.downField("requestBody").as[Option[OpenapiRequestBody]]
summary <- c.downField("summary").as[Option[String]]
Expand All @@ -183,7 +183,10 @@ object OpenapiModels {

implicit val PartialOpenapiPathDecoder: Decoder[OpenapiPath] = { (c: HCursor) =>
for {
parameters <- c.downField("parameters").as[Seq[Resolvable[OpenapiParameter]]].orElse(Right(List.empty[Resolvable[OpenapiParameter]]))
parameters <- c
.downField("parameters")
.as[Option[Seq[Resolvable[OpenapiParameter]]]]
.map(_.getOrElse(Nil))
methods <- List("get", "put", "post", "delete", "options", "head", "patch", "patch", "connect")
.traverse(method => c.downField(method).as[Option[OpenapiPathMethod]].map(_.map(_.copy(methodType = method))))
} yield OpenapiPath("--partial--", methods.flatten, parameters)
Expand All @@ -202,7 +205,7 @@ object OpenapiModels {
openapi <- c.downField("openapi").as[String]
info <- c.downField("info").as[OpenapiInfo]
paths <- c.downField("paths").as[Seq[OpenapiPath]]
components <- c.downField("components").as[Option[OpenapiComponent]].orElse(Right(None))
components <- c.downField("components").as[Option[OpenapiComponent]]
} yield OpenapiDocument(openapi, info, paths, components)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaEnum,
OpenapiSchemaMap,
OpenapiSchemaObject,
OpenapiSchemaRef,
OpenapiSchemaString
}
import sttp.tapir.codegen.testutils.CompileCheckTestBase
Expand Down Expand Up @@ -258,9 +259,31 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
val gen = new ClassDefinitionGenerator()
val res = gen.classDefs(doc, true)
// can't just check whether this compiles, because our tests only run on scala 2.12 - so instead just eyeball it...
res shouldBe Some(
"""enum Test {
res shouldBe Some("""enum Test {
| case enum1, enum2
|}""".stripMargin)
}

it should "generate named maps" in {
val doc = OpenapiDocument(
"",
null,
null,
Some(
OpenapiComponent(
Map(
"MyObject" -> OpenapiSchemaObject(Map("text" -> OpenapiSchemaString(true)), Seq("text"), false),
"MyEnum" -> OpenapiSchemaEnum(Seq(OpenapiSchemaConstantString("enum1"), OpenapiSchemaConstantString("enum2")), false),
"MyMapPrimitive" -> OpenapiSchemaMap(OpenapiSchemaString(false), false),
"MyMapObject" -> OpenapiSchemaMap(OpenapiSchemaRef("#/components/schemas/MyObject"), false),
"MyMapEnum" -> OpenapiSchemaMap(OpenapiSchemaRef("#/components/schemas/MyEnum"), false)
)
)
)
)

val gen = new ClassDefinitionGenerator()
val res = gen.classDefs(doc, false)
"import enumeratum._;" + res.get shouldCompile ()
}
}

0 comments on commit 425d65c

Please sign in to comment.