diff --git a/ai-core/src/main/scala/wvlet/ai/core/weaver/CaseClassWeaver.scala b/ai-core/src/main/scala/wvlet/ai/core/weaver/CaseClassWeaver.scala new file mode 100644 index 0000000..88285ba --- /dev/null +++ b/ai-core/src/main/scala/wvlet/ai/core/weaver/CaseClassWeaver.scala @@ -0,0 +1,101 @@ +package wvlet.ai.core.weaver + +import scala.deriving.Mirror // Keep Mirror for `m` +// erasedValue, summonInline, constValue, error are no longer needed here +import wvlet.ai.core.msgpack.spi.{Packer, Unpacker} + +// Removed duplicate ObjectWeaver trait. +// The canonical one is in ObjectWeaver.scala + +/** + * Custom exception for errors occurring during weaver packing. + * @param message + * A description of the error. + * @param cause + * The underlying cause of the error, if any. + */ +case class WeaverPackingException(message: String, cause: Throwable = null) + extends RuntimeException(message, cause) + +// Companion object removed for this attempt + +// Constructor now accepts elementWeavers. Mirror m is still needed for fromProduct. +class CaseClassWeaver[A](private val elementWeavers: List[ObjectWeaver[?]])(using + m: Mirror.ProductOf[A] +) extends ObjectWeaver[A]: + + // Internal buildWeavers and elementWeavers val are removed. + + override def pack(packer: Packer, v: A, config: WeaverConfig): Unit = + val product = v.asInstanceOf[Product] + if product.productArity != elementWeavers.size then + throw WeaverPackingException( + s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${product.productArity}" + ) + packer.packArrayHeader(elementWeavers.size) + + product + .productIterator + .zip(elementWeavers) + .foreach { case (elemValue, weaver) => + (weaver.asInstanceOf[ObjectWeaver[Any]]).pack(packer, elemValue, config) + } + + override def unpack(unpacker: Unpacker, context: WeaverContext): Unit = + val numElements = unpacker.unpackArrayHeader + if numElements != elementWeavers.size then + context.setError( + new IllegalArgumentException( + s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${numElements}" + ) + ) + // This point is for future consideration of schema evolution or robust error recovery. + // For now, strict element count matching is enforced. + return + + val elements = new Array[Any](elementWeavers.size) + var i = 0 + var failed = false + + while i < elementWeavers.size && !failed do + val weaver = elementWeavers(i) + val elementContext = WeaverContext(context.config) + // Assuming weaver is ObjectWeaver[?] so direct call is not possible without cast + // However, the element type is unknown here to do a safe cast. + // This part of unpack will need careful handling if we stick to List[ObjectWeaver[?]] + (weaver.asInstanceOf[ObjectWeaver[Any]]).unpack(unpacker, elementContext) + + if elementContext.hasError then + context.setError( + new RuntimeException( + s"Failed to unpack element $i: ${elementContext.getError.get.getMessage}", + elementContext.getError.get + ) + ) + failed = true + else + elements(i) = elementContext.getLastValue + i += 1 + + if !failed then + try + val instance = m.fromProduct( + new Product: + override def productArity: Int = elements.length + override def productElement(n: Int): Any = elements(n) + override def canEqual(that: Any): Boolean = + that.isInstanceOf[Product] && that.asInstanceOf[Product].productArity == productArity + ) + context.setObject(instance) + catch + case e: Throwable => + context.setError(new RuntimeException("Failed to instantiate case class from product", e)) + // Closing brace for try-catch + // Closing brace for if (!failed) + // If failed, context will already have an error set. + // Closing brace for unpack method + end unpack + + // Closing brace for CaseClassWeaver class + +end CaseClassWeaver diff --git a/ai-core/src/main/scala/wvlet/ai/core/weaver/ObjectWeaver.scala b/ai-core/src/main/scala/wvlet/ai/core/weaver/ObjectWeaver.scala index f0ee2aa..c9a420f 100644 --- a/ai-core/src/main/scala/wvlet/ai/core/weaver/ObjectWeaver.scala +++ b/ai-core/src/main/scala/wvlet/ai/core/weaver/ObjectWeaver.scala @@ -2,6 +2,8 @@ package wvlet.ai.core.weaver import wvlet.ai.core.msgpack.spi.{MessagePack, MsgPack, Packer, Unpacker} import wvlet.ai.core.weaver.codec.{JSONWeaver, PrimitiveWeaver} +import scala.deriving.Mirror +import scala.compiletime.{constValue, summonInline} trait ObjectWeaver[A]: def weave(v: A, config: WeaverConfig = WeaverConfig()): MsgPack = toMsgPack(v, config) @@ -61,3 +63,19 @@ object ObjectWeaver: ): A = weaver.fromJson(json, config) export PrimitiveWeaver.given + + private inline def buildWeaverList[ElemTypes <: Tuple]( + idx: Int + ): List[ObjectWeaver[?]] = // Removed inline from idx + inline if idx >= constValue[Tuple.Size[ElemTypes]] then // Base case: index out of bounds + Nil + else + // Summons ObjectWeaver for the element type at the current index + val headWeaver = summonInline[ObjectWeaver[Tuple.Elem[ElemTypes, idx.type]]] + headWeaver :: buildWeaverList[ElemTypes](idx + 1) // Recursive call + + inline given [A](using m: Mirror.ProductOf[A]): ObjectWeaver[A] = + val weavers = buildWeaverList[m.MirroredElemTypes](0) + new CaseClassWeaver[A](weavers)(using m) + +end ObjectWeaver diff --git a/ai-core/src/main/scala/wvlet/ai/core/weaver/codec/PrimitiveWeaver.scala b/ai-core/src/main/scala/wvlet/ai/core/weaver/codec/PrimitiveWeaver.scala index 7fb7613..d6fa842 100644 --- a/ai-core/src/main/scala/wvlet/ai/core/weaver/codec/PrimitiveWeaver.scala +++ b/ai-core/src/main/scala/wvlet/ai/core/weaver/codec/PrimitiveWeaver.scala @@ -610,4 +610,25 @@ object PrimitiveWeaver: u.skipValue context.setError(new IllegalArgumentException(s"Cannot convert ${other} to ListMap")) + inline given optionWeaver[T](using elementWeaver: => ObjectWeaver[T]): ObjectWeaver[Option[T]] = + new ObjectWeaver[Option[T]]: + override def pack(p: Packer, v: Option[T], config: WeaverConfig): Unit = + v match + case Some(value) => + elementWeaver.pack(p, value, config) + case None => + p.packNil // Corrected: removed parentheses + + override def unpack(u: Unpacker, context: WeaverContext): Unit = + if u.tryUnpackNil then + context.setObject(None) + else + // Need a fresh context for the element, in case of error or nested structures + val elementContext = WeaverContext(context.config) + elementWeaver.unpack(u, elementContext) + if elementContext.hasError then + context.setError(elementContext.getError.get) + else + context.setObject(Some(elementContext.getLastValue.asInstanceOf[T])) + end PrimitiveWeaver diff --git a/ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala b/ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala index d2c69b4..eddf2b9 100644 --- a/ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala +++ b/ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala @@ -1,8 +1,15 @@ package wvlet.ai.core.weaver import wvlet.airspec.AirSpec +import wvlet.ai.core.weaver.ObjectWeaver // Ensure ObjectWeaver is imported if not already fully covered import scala.jdk.CollectionConverters.* +// Define case classes for testing +case class SimpleCase(i: Int, s: String, b: Boolean) +case class NestedCase(name: String, simple: SimpleCase) +case class OptionCase(id: Int, opt: Option[String]) +case class SeqCase(key: String, values: Seq[Int]) + class WeaverTest extends AirSpec: test("weave int") { @@ -497,4 +504,95 @@ class WeaverTest extends AirSpec: result.get.getMessage.contains("Cannot convert") shouldBe true } + // Tests for SimpleCase + test("weave SimpleCase") { + val v = SimpleCase(10, "test case", true) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[SimpleCase](msgpack) + v shouldBe v2 + } + + test("SimpleCase toJson") { + val v = SimpleCase(20, "json test", false) + val json = ObjectWeaver.toJson(v) + val v2 = ObjectWeaver.fromJson[SimpleCase](json) + v shouldBe v2 + } + + // Tests for NestedCase + test("weave NestedCase") { + val v = NestedCase("nested", SimpleCase(30, "inner", true)) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[NestedCase](msgpack) + v shouldBe v2 + } + + test("NestedCase toJson") { + val v = NestedCase("nested json", SimpleCase(40, "inner json", false)) + val json = ObjectWeaver.toJson(v) + val v2 = ObjectWeaver.fromJson[NestedCase](json) + v shouldBe v2 + } + + // Tests for OptionCase + test("weave OptionCase with Some") { + val v = OptionCase(50, Some("option value")) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[OptionCase](msgpack) + v shouldBe v2 + } + + test("OptionCase toJson with Some") { + val v = OptionCase(60, Some("option json")) + val json = ObjectWeaver.toJson(v) + val v2 = ObjectWeaver.fromJson[OptionCase](json) + v shouldBe v2 + } + + test("weave OptionCase with None") { + val v = OptionCase(70, None) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[OptionCase](msgpack) + v shouldBe v2 + } + + test("OptionCase toJson with None") { + val v = OptionCase(80, None) + val json = ObjectWeaver.toJson(v) + // Check against expected JSON for None, as direct None might be ambiguous for fromJson + // Depending on JSON library, None might be represented as null or omitted + // Assuming it's represented as null or handled by the weaver + val v2 = ObjectWeaver.fromJson[OptionCase](json) + v shouldBe v2 + } + + // Tests for SeqCase + test("weave SeqCase with non-empty Seq") { + val v = SeqCase("seq test", Seq(1, 2, 3, 4)) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[SeqCase](msgpack) + v shouldBe v2 + } + + test("SeqCase toJson with non-empty Seq") { + val v = SeqCase("seq json", Seq(5, 6, 7)) + val json = ObjectWeaver.toJson(v) + val v2 = ObjectWeaver.fromJson[SeqCase](json) + v shouldBe v2 + } + + test("weave SeqCase with empty Seq") { + val v = SeqCase("empty seq", Seq.empty[Int]) + val msgpack = ObjectWeaver.weave(v) + val v2 = ObjectWeaver.unweave[SeqCase](msgpack) + v shouldBe v2 + } + + test("SeqCase toJson with empty Seq") { + val v = SeqCase("empty seq json", Seq.empty[Int]) + val json = ObjectWeaver.toJson(v) + val v2 = ObjectWeaver.fromJson[SeqCase](json) + v shouldBe v2 + } + end WeaverTest