diff --git a/setup.py b/setup.py index 65be738..b4ed0f3 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setuptools.setup( name='tree-math', description='Mathematical operations for JAX pytrees', - version='0.2.0 ', + version='0.2.1', license='Apache 2.0', author='Google LLC', author_email='noreply@google.com', diff --git a/tree_math/__init__.py b/tree_math/__init__.py index 981869e..c2e0241 100644 --- a/tree_math/__init__.py +++ b/tree_math/__init__.py @@ -24,4 +24,4 @@ from tree_math._src.vector import Vector, VectorMixin import tree_math.numpy -__version__ = '0.2.0' +__version__ = '0.2.1' diff --git a/tree_math/_src/structs.py b/tree_math/_src/structs.py index 57e135d..2a16af6 100644 --- a/tree_math/_src/structs.py +++ b/tree_math/_src/structs.py @@ -72,6 +72,7 @@ def tree_unflatten(cls, _, children): {'fields': fields, 'asdict': asdict, 'astuple': astuple, + 'replace': dataclasses.replace, 'tree_flatten': tree_flatten, 'tree_unflatten': tree_unflatten, '__module__': cls.__module__}) diff --git a/tree_math/_src/structs_test.py b/tree_math/_src/structs_test.py index 5b9670e..879e69c 100644 --- a/tree_math/_src/structs_test.py +++ b/tree_math/_src/structs_test.py @@ -109,6 +109,12 @@ def testPickle(self): restored = pickle.loads(pickle.dumps(struct)) self.assertTreeEqual(struct, restored, check_dtypes=True) + def testReplace(self): + struct = TestStruct(1, 2) + replaced = struct.replace(b=3) + expected = TestStruct(1, 3) + self.assertTreeEqual(replaced, expected, check_dtypes=True) + if __name__ == '__main__': absltest.main()