diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..f22b4e2f5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1844,6 +1844,34 @@ def test_pickling_hash(): # }}} +def test_einsum_tagging(): + from pytools.tag import UniqueTag + + class FTag(UniqueTag): + pass + + class ATag(UniqueTag): + pass + + class CTag(UniqueTag): + pass + + class DTag(UniqueTag): + pass + + p = (pt.zeros((2, 3, 4, 5)) + .with_tagged_axis(0, FTag()) + .with_tagged_axis(1, ATag()) + .with_tagged_axis(2, CTag()) + .with_tagged_axis(3, DTag())) + + a = pt.zeros((2, 3, 3)) + + result = pt.einsum("facb, fdce, fad -> cbe", p, p, a) + + pt.unify_axes_tags(result) + + if __name__ == "__main__": import os if "INVOCATION_INFO" in os.environ: