Skip to content

Commit

Permalink
Refactor handling default discriminator in writers
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski committed Oct 5, 2023
1 parent 6368534 commit 7088450
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,16 @@ object Pickler:
macroProductW[T](
schema,
childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.writer).productIterator.toList,
childDefaults
childDefaults,
config
)
override lazy val reader: Reader[T] =
macroProductR[T](schema, childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader), childDefaults)(
using product
macroProductR[T](
schema,
childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader),
childDefaults,
product
)

}
Pickler[T](tapirPickle, schema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper
LeafWrapper(new TaggedReader.Leaf[V](n, rw), rw, n)
}

inline def macroProductR[T](schema: Schema[T], childReaders: Tuple, childDefaults: List[Option[Any]])(using
m: Mirror.ProductOf[T]
): Reader[T] =
inline def macroProductR[T](schema: Schema[T], childReaders: Tuple, childDefaults: List[Option[Any]], m: Mirror.ProductOf[T]): Reader[T] =
val schemaFields = schema.schemaType.asInstanceOf[SchemaType.SProduct[T]].fields

val reader = new CaseClassReadereader[T](upickleMacros.paramsCount[T], upickleMacros.checkErrorMissingKeysCount[T]()) {
Expand Down Expand Up @@ -69,11 +67,14 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper

new TaggedReader.Node[T](readersFromMapping.asInstanceOf[Seq[TaggedReader[T]]]: _*)

case DefaultSubtypeDiscriminator[T](_, toValue) =>
case discriminator: DefaultSubtypeDiscriminator[T] =>
val readers = childPicklers.map(cp => {
(cp.schema.name, cp.innerUpickle.reader) match {
case (Some(sName), wrappedReader: Readers#LeafWrapper[_]) =>
TaggedReader.Leaf[T](toValue(sName), wrappedReader.r.asInstanceOf[Reader[T]])
TaggedReader.Leaf[T](
discriminator.toValue(sName),
wrappedReader.r.asInstanceOf[Reader[T]]
)
case _ =>
cp.innerUpickle.reader.asInstanceOf[Reader[T]]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ private[pickler] trait Writers extends WritersVersionSpecific with UpickleHelper
inline def macroProductW[T: ClassTag](
schema: Schema[T],
childWriters: => List[Any],
childDefaults: => List[Option[Any]]
)(using
Configuration
childDefaults: => List[Option[Any]],
config: PicklerConfiguration
) =
lazy val writer = new CaseClassWriter[T] {
def length(v: T) = upickleMacros.writeLength[T](outerThis, v)
Expand Down Expand Up @@ -69,7 +68,7 @@ private[pickler] trait Writers extends WritersVersionSpecific with UpickleHelper
inline if upickleMacros.isMemberOfSealedHierarchy[T] && !isEnumeration[T] then
annotate[T](
writer,
upickleMacros.tagName[T],
schema.name.map(config.toDiscriminatorValue).getOrElse(upickleMacros.tagName[T]),
Annotator.Checker.Cls(implicitly[ClassTag[T]].runtimeClass)
) // tagName is responsible for extracting the @tag annotation meaning the discriminator value
else if upickleMacros.isSingleton[T]
Expand All @@ -91,11 +90,8 @@ private[pickler] trait Writers extends WritersVersionSpecific with UpickleHelper
val (tag, w) = super.findWriter(v)
val overriddenTag = discriminator.writeUnsafe(v) // here we use our discirminator instead of uPickle's
(overriddenTag, w)

case DefaultSubtypeDiscriminator[T](_, toValue) =>
val (t, writer) = super.findWriter(v)
val t2 = toValue(SName(t, Nil)) // TODO
(t2, writer)
case _ =>
super.findWriter(v)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class PicklerCoproductTest extends AnyFlatSpec with Matchers {
decoded shouldBe Value(inputObject)
}

it should "use custom discriminator name function" in {
it should "use custom discriminator value function" in {
// given
import generic.auto.* // for Pickler auto-derivation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sttp.tapir.json.pickler
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import sttp.tapir.DecodeResult.Value
import sttp.tapir.generic.Configuration
import sttp.tapir.{Schema, SchemaType}
import upickle.core.ObjVisitor

Expand Down Expand Up @@ -80,7 +79,7 @@ class PicklerEnumTest extends AnyFlatSpec with Matchers {
// given
import generic.auto.* // for Pickler auto-derivation
val inputObj = SealedVariantContainer(VariantA)

// when
val pickler = Pickler.derived[SealedVariantContainer]
val codec = pickler.toCodec
Expand Down

0 comments on commit 7088450

Please sign in to comment.