diff --git a/build.sbt b/build.sbt index 99981ee3a0..7060be699a 100644 --- a/build.sbt +++ b/build.sbt @@ -10,7 +10,7 @@ val disciplineMunitVersion = "1.0.9" val flywayVersion = "9.20.0" val fs2AwsVersion = "6.2.0" val fs2Version = "3.12.0" -val grackleVersion = "0.24.0" +val grackleVersion = "0.24.0-28-d64fd49-20250702T151401Z-SNAPSHOT" //"0.24.0" val http4sBlazeVersion = "0.23.17" val http4sEmberVersion = "0.23.30" val http4sJdkHttpClientVersion = "0.10.0" diff --git a/modules/service/src/main/scala/lucuma/odb/graphql/OdbMapping.scala b/modules/service/src/main/scala/lucuma/odb/graphql/OdbMapping.scala index 76cb2c3847..3213eeb98c 100644 --- a/modules/service/src/main/scala/lucuma/odb/graphql/OdbMapping.scala +++ b/modules/service/src/main/scala/lucuma/odb/graphql/OdbMapping.scala @@ -102,6 +102,7 @@ object OdbMapping { shouldValidate:Boolean = true, // should we validatate the TypeMappings? ): Mapping[F] = new SkunkMapping[F](database, monitor0) + // with FetchLimit[F](10000) with BaseMapping[F] with ArcMapping[F] with AddAtomEventResultMapping[F] @@ -597,7 +598,7 @@ object OdbMapping { super.defaultRootCursor(query, tpe, parentCursor) // Override `fetch` to log the SQL query. This is optional. - override def fetch(fragment: AppliedFragment, codecs: List[(Boolean, Codec)]): F[Vector[Array[Any]]] = { + override def fetch(fragment: AppliedFragment, codecs: List[(Boolean, Codec)]): F[Result[Vector[Array[Any]]]] = { Logger[F].debug { val formatted = SqlFormatter.format(fragment.fragment.sql) val cleanedUp = formatted.replaceAll("\\$ (\\d+)", "\\$$1") // turn $ 42 into $42 @@ -606,7 +607,7 @@ object OdbMapping { } *> super.fetch(fragment, codecs) } - + // HACK: If the codec is a DomainCodec then use the domain name when generating `null::` in Grackle queries override implicit def Fragments: SqlFragment[AppliedFragment] = val delegate = super.Fragments diff --git a/modules/service/src/main/scala/lucuma/odb/graphql/util/FetchLimit.scala b/modules/service/src/main/scala/lucuma/odb/graphql/util/FetchLimit.scala new file mode 100644 index 0000000000..2fa8bee62b --- /dev/null +++ b/modules/service/src/main/scala/lucuma/odb/graphql/util/FetchLimit.scala @@ -0,0 +1,45 @@ +// Copyright (c) 2016-2025 Association of Universities for Research in Astronomy, Inc. (AURA) +// For license information see LICENSE or https://opensource.org/licenses/BSD-3-Clause + +package lucuma.odb.graphql.util + +import cats.effect.Async +import cats.effect.Resource +import cats.syntax.all.* +import grackle.Result +import grackle.skunk.SkunkMapping +import scala.util.control.NonFatal +import lucuma.odb.data.OdbError +import lucuma.odb.data.OdbErrorExtensions.* +import skunk.Decoder + +trait FetchLimit[F[_]: Async](maxBytes: Long) extends SkunkMapping[F]: + + case class FetchLimitError(maxBytes: Long) extends Exception(s"Fetch limit ($maxBytes bytes) exceeded.") + + // Given a decoder, return an equivalent decoder that also yields the total length of the underly row data + def countingDecoder[A](dec: Decoder[A]): Decoder[(A, Long)] = + new Decoder[(A, Long)]: + override def types = dec.types + override def decode(offset: Int, ss: List[Option[String]]): Either[Decoder.Error, (A, Long)] = + dec.decode(offset, ss).map((_, ss.foldMap(_.foldMap(_.length.toLong)))) + + override def fetch(fragment: Fragment, codecs: List[(Boolean, Codec)]): F[Result[Vector[Array[Any]]]] = + pool + .use: s => + Resource.eval(s.prepare(fragment.fragment.query(countingDecoder(rowDecoder(codecs))))).use: ps => + ps.stream(fragment.argument, 1024) + .evalMapAccumulate(0L): + case (prev, (arr, size)) => + val next = prev + size + Async[F].raiseWhen(next > maxBytes)(FetchLimitError(maxBytes)) >> + (next, arr).pure[F] + .compile + .toVector + .map: vec => + Result.success(vec.map(_._2)) + .recover: + case e: FetchLimitError => OdbError.RemoteServiceCallError(Some(e.getMessage())).asFailure + .onError: + case NonFatal(e) => Async[F].delay(e.printStackTrace()) +