diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 29109ac151ccce..2fb61ae2731d23 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -226,6 +226,7 @@ def new_init(self, *args, **kwargs):
     get_autocast_gpu_dtype,
     is_autocast_enabled,
 )
+from .amp.auto_cast import autocast
 from .autograd import (
     enable_grad,
     grad,
@@ -971,7 +972,6 @@ def __dir__(self):
 sub = subtract
 sub_ = subtract_
 
-
 __all__ = [
     'block_diag',
     'gt',
@@ -1481,6 +1481,7 @@ def __dir__(self):
     'conv3d',
     'manual_seed',
     'softmax',
+    'autocast',
 ]
 
 import os
diff --git a/python/paddle/functional.py b/python/paddle/functional.py
index 96e0c5eb6106bc..6642f3867899b8 100644
--- a/python/paddle/functional.py
+++ b/python/paddle/functional.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .compat import split
+from .tensor.creation import meshgrid
 from .tensor.einsum import einsum
 from .tensor.linalg import norm
 from .tensor.manipulation import (
@@ -31,4 +32,5 @@
     "norm",
     'split',
     'unique_consecutive',
+    "meshgrid",
 ]
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index db823aa97d7f1e..cd5c5e702c7245 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -172,6 +172,7 @@
     pixel_unshuffle,
 )
 
+logsigmoid = log_sigmoid
 __all__ = [
     'celu',
     'conv1d',
@@ -192,6 +193,7 @@
     'leaky_relu',
     'leaky_relu_',
     'log_sigmoid',
+    'logsigmoid',
     'maxout',
     'prelu',
     'relu',
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index b7b63d5c7c1323..b5016f298ed890 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -754,6 +754,7 @@ def relu_(x: Tensor, name: str | None = None) -> Tensor:
     return _C_ops.relu_(x)
 
 
+@param_one_alias(["x", "input"])
 def log_sigmoid(x: Tensor, name: str | None = None) -> Tensor:
     r"""
     log_sigmoid activation.
@@ -764,6 +765,7 @@ def log_sigmoid(x: Tensor, name: str | None = None) -> Tensor:
 
     Parameters:
         x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
+            Alias: ``input``.
         name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
     Returns:
diff --git a/python/paddle/nn/init.py b/python/paddle/nn/init.py
index ad6116ddcb64e4..5dc730c6e43e0d 100644
--- a/python/paddle/nn/init.py
+++ b/python/paddle/nn/init.py
@@ -14,6 +14,8 @@
 
 from __future__ import annotations
 
+import numpy as np
+
 import paddle
 
 from ..base.framework import in_dygraph_mode, in_pir_mode
@@ -27,6 +29,41 @@
 from .initializer.xavier import XavierNormal, XavierUniform
 
 
+def _calculate_fan_in_and_fan_out(var: paddle.Tensor) -> tuple[int, int]:
+    """Compute the fan_in and the fan_out for layers
+
+    This method computes the fan_in and the fan_out
+    for neural network layers, if not specified. It is
+    not possible to perfectly estimate fan_in and fan_out.
+    This method will estimate it correctly for matrix multiply and
+    convolutions.
+
+    Args:
+        var: variable for which fan_in and fan_out have to be computed.
+
+    Returns:
+        tuple of two integers (fan_in, fan_out).
+    """
+    shape = var.shape
+    if not shape or len(shape) == 0:
+        fan_in = fan_out = 1
+    elif len(shape) == 1:
+        fan_in = fan_out = shape[0]
+    elif len(shape) == 2:
+        # This is the case for simple matrix multiply
+        fan_in = shape[0]
+        fan_out = shape[1]
+    else:
+        # Assume this to be a convolutional kernel
+        # In PaddlePaddle, the shape of the kernel is like:
+        # [num_filters, num_filter_channels, ...] where the remaining
+        # dimensions are the filter_size
+        receptive_field_size = np.prod(shape[2:])
+        fan_in = int(shape[1] * receptive_field_size)
+        fan_out = int(shape[0] * receptive_field_size)
+    return (fan_in, fan_out)
+
+
 def kaiming_uniform_(
     tensor: paddle.Tensor,
     a: float = 0,
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index b6d3d3bdc50847..026b051b337104 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -515,6 +515,7 @@
 greater = gt
 sub = subtract
 sub_ = subtract_
+clamp_ = clip_
 
 # this list used in math_op_patch.py for _binary_creator_
 tensor_method_func = [
@@ -947,6 +948,7 @@
     'gt',
     'greater',
     'clamp',
+    'clamp_',
 ]
 
 
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index a72407c157555f..705a16896b992c 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -854,6 +854,53 @@ def test_errors(self):
             F.log_sigmoid(x_fp16)
 
 
+class TestLogSigmoidOutAndParaDecorator(unittest.TestCase):
+    def setUp(self) -> None:
+        paddle.disable_static()
+        self.apis = [
+            paddle.nn.functional.log_sigmoid,
+            paddle.nn.functional.logsigmoid,
+        ]
+        self.shape = [3, 4, 5]
+        self.input_np = np.random.random(self.shape).astype('float32')
+
+    def do_test(self, api, test_type):
+        self.test_types = [
+            "decorator1",
+        ]
+        x = paddle.to_tensor(self.input_np, stop_gradient=False)
+        out = paddle.zeros(self.shape, dtype='float32')
+        out.stop_gradient = False
+        if test_type == "raw":
+            out = paddle.nn.functional.log_sigmoid(x)
+            out.mean().backward()
+            return out, x.grad
+        elif test_type == "decorator1":
+            res = api(input=x)
+            loss = res.mean()
+            loss.backward()
+            x_grad = x.grad
+            return res, x_grad
+        else:
+            raise NotImplementedError(
+                f"Test type {test_type} is not implemented."
+            )
+
+    def test_api(self):
+        out_std, x_grad_std = self.do_test(
+            paddle.nn.functional.log_sigmoid, "raw"
+        )
+        for api in self.apis:
+            for test_type in self.test_types:
+                out, x_grad = self.do_test(api, test_type)
+                np.testing.assert_allclose(
+                    out.numpy(), out_std.numpy(), rtol=1e-20
+                )
+                np.testing.assert_allclose(
+                    x_grad.numpy(), x_grad_std.numpy(), rtol=1e-20
+                )
+
+
 class TestTanh(TestActivation, TestParameter):
     def setUp(self):
         self.op_type = "tanh"
diff --git a/test/legacy_test/test_autocast.py b/test/legacy_test/test_autocast.py
new file mode 100644
index 00000000000000..e4d16b1b6211c8
--- /dev/null
+++ b/test/legacy_test/test_autocast.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle.base import core
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(),
+    "Require compiled with CUDA or XPU.",
+)
+@unittest.skipIf(
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] < 7.0,
+    "run test when gpu's compute capability is at least 7.0.",
+)
+@unittest.skipIf(
+    core.is_compiled_with_xpu()
+    and core.get_xpu_device_version(0) < core.XPUVersion.XPU3,
+    "run test when xpu's compute capability >= xpu3.",
+)
+@unittest.skipIf(
+    core.is_compiled_with_xpu()
+    and core.get_xpu_device_version(0) == core.XPUVersion.XPU3,
+    "Bugs on XPU3, disable temporarily",
+)
+class TestCudaAutoCast(unittest.TestCase):
+    def setUp(self):
+        self._conv = paddle.nn.Conv2D(1, 1, 3, bias_attr=False)
+        self._linear = paddle.nn.Linear(4, 4)
+
+    def _run_autocast_test(self, ctx):
+        with paddle.autocast(
+            device_type='cuda',
+            enabled=True,
+            dtype=paddle.float16,
+            cache_enabled=True,
+        ):
+            out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
+            out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
+            out3 = self._linear(out2)
+
+        self.assertEqual(out1.dtype, paddle.float16)
+        self.assertEqual(out2.dtype, paddle.float16)
+        self.assertEqual(out3.dtype, paddle.float32)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/legacy_test/test_clip_op.py b/test/legacy_test/test_clip_op.py
index de37d48303782c..480f08c59e3f41 100644
--- a/test/legacy_test/test_clip_op.py
+++ b/test/legacy_test/test_clip_op.py
@@ -1033,5 +1033,37 @@ def test_static_compatibility(self):
         np.testing.assert_array_equal(self.np_out, fetches[0])
 
 
+class TestClamp_AndClip_(unittest.TestCase):
+    def setUp(self) -> None:
+        paddle.disable_static()
+        self.shape = [3, 4, 5]
+        self.input_np = np.random.random(self.shape).astype('float32')
+        self.a = np.random.random(self.shape).astype('float32')
+        self.b = np.random.random(self.shape).astype('float32')
+        self.min, self.max = -0.5, 0.5
+
+    def test_clip_and_clamp(self):
+        clip_a = paddle.to_tensor(self.a, stop_gradient=False)
+        clip_b = paddle.to_tensor(self.b, stop_gradient=False)
+
+        clamp_a = paddle.to_tensor(self.a, stop_gradient=False)
+        clamp_b = paddle.to_tensor(self.b, stop_gradient=False)
+
+        clip_x = clip_a + clip_b
+        clip_x.clip_(min=self.min, max=self.max)
+        clip_x.retain_grads()
+        clip_x.mean().backward()
+
+        clamp_x = clamp_a + clamp_b
+        clamp_x.clamp_(min=self.min, max=self.max)
+        clamp_x.retain_grads()
+        clamp_x.mean().backward()
+
+        np.testing.assert_allclose(clip_x.numpy(), clamp_x.numpy(), rtol=1e-20)
+        np.testing.assert_allclose(
+            clip_x.grad.numpy(), clamp_x.grad.numpy(), rtol=1e-20
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/legacy_test/test_nn_init_function.py b/test/legacy_test/test_nn_init_function.py
index fb21baacb72e72..ad6ccf89c020b7 100644
--- a/test/legacy_test/test_nn_init_function.py
+++ b/test/legacy_test/test_nn_init_function.py
@@ -62,6 +62,22 @@ def _calculate_gain(nonlinearity, param):
     return recommended_gain[nonlinearity]
 
 
+def _calculate_fan_in_and_fan_out(var: paddle.Tensor) -> tuple[int, int]:
+    shape = var.shape
+    if not shape or len(shape) == 0:
+        fan_in = fan_out = 1
+    elif len(shape) == 1:
+        fan_in = fan_out = shape[0]
+    elif len(shape) == 2:
+        fan_in = shape[0]
+        fan_out = shape[1]
+    else:
+        receptive_field_size = np.prod(shape[2:])
+        fan_in = shape[1] * receptive_field_size
+        fan_out = shape[0] * receptive_field_size
+    return (fan_in, fan_out)
+
+
 class Test_calculate_gain(unittest.TestCase):
     def test(self):
         for nonlinearity in [
@@ -87,6 +103,27 @@ def test(self):
             )
 
 
+class TestCAlFanINOUT(unittest.TestCase):
+    def test_cal_fan_in_and_out(self):
+        x = paddle.tensor.randn([10])
+        self.assertEqual(
+            _calculate_fan_in_and_fan_out(x),
+            paddle.nn.init._calculate_fan_in_and_fan_out(x),
+        )
+
+        y = paddle.tensor.randn([10, 10])
+        self.assertEqual(
+            _calculate_fan_in_and_fan_out(y),
+            paddle.nn.init._calculate_fan_in_and_fan_out(y),
+        )
+
+        z = paddle.randn([10, 10, 10])
+        self.assertEqual(
+            _calculate_fan_in_and_fan_out(z),
+            paddle.nn.init._calculate_fan_in_and_fan_out(z),
+        )
+
+
 class Test_kaiming_uniform_(unittest.TestCase):
     def check_kaiming_uniform(
         self, tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'