Merged
Changes from all commits
27 commits
ace5747 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
4911b59 - Merge branch 'PaddlePaddle:develop' into develop (AlAuAu, Oct 27, 2025)
6b24646 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
e054b48 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
fbde952 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
333bae9 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
143d785 - sharding stage3 bugfix (AlAuAu, Oct 22, 2025)
183a930 - Merge branch 'PaddlePaddle:develop' into develop (AlAuAu, Nov 3, 2025)
312fc19 - Merge branch 'PaddlePaddle:develop' into develop (AlAuAu, Nov 3, 2025)
7fd48d9 - support recompute's forward and backward in pipeline mode (AlAuAu, Nov 3, 2025)
f97edd4 - Merge branch 'PaddlePaddle:develop' into develop (AlAuAu, Nov 4, 2025)
025efc3 - [API Compatibility] Add paddle.Tensor.clip_ (AlAuAu, Nov 4, 2025)
dd52286 - Merge branch 'PaddlePaddle:develop' into develop (AlAuAu, Nov 5, 2025)
4e54642 - Revert "support recompute's forward and backward in pipeline mode" (AlAuAu, Nov 5, 2025)
f45380c - Revert "[API Compatibility] Add paddle.Tensor.clip_" (AlAuAu, Nov 5, 2025)
dc28d5e - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
a8f7186 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
7fcdd48 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
6431d8f - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
9d93a48 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
fdb2a25 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
b2f67a7 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
fdce602 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
767effe - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
97e5dc1 - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
b228dce - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
56deece - [API Compatibility] Add clip_、logsigmoid、_calculate_fan_in_and_fan_ou… (AlAuAu, Nov 5, 2025)
3 changes: 2 additions & 1 deletion python/paddle/__init__.py
@@ -226,6 +226,7 @@ def new_init(self, *args, **kwargs):
get_autocast_gpu_dtype,
is_autocast_enabled,
)
from .amp.auto_cast import autocast
from .autograd import (
enable_grad,
grad,
@@ -971,7 +972,6 @@ def __dir__(self):
sub = subtract
sub_ = subtract_


__all__ = [
'block_diag',
'gt',
@@ -1481,6 +1481,7 @@ def __dir__(self):
'conv3d',
'manual_seed',
'softmax',
'autocast',
]
import os

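As context for this export, a minimal usage sketch (not part of the diff, assuming a CUDA build): paddle.autocast is the paddle.amp.auto_cast.autocast context manager re-exported at the top level, and the arguments below mirror those exercised in the new test_autocast.py later in this PR.

import paddle

# Hypothetical usage sketch; arguments mirror those used in test_autocast.py.
conv = paddle.nn.Conv2D(1, 1, 3, bias_attr=False)
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')

with paddle.autocast(device_type='cuda', enabled=True, dtype=paddle.float16):
    out = conv(x)

# The convolution runs in float16 inside the autocast region,
# matching the assertion in the new test.
assert out.dtype == paddle.float16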
2 changes: 2 additions & 0 deletions python/paddle/functional.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .compat import split
from .tensor.creation import meshgrid
from .tensor.einsum import einsum
from .tensor.linalg import norm
from .tensor.manipulation import (
@@ -31,4 +32,5 @@
"norm",
'split',
'unique_consecutive',
"meshgrid",
]
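For reference, a short sketch of the newly re-exported entry point (not part of the diff; it assumes paddle.functional.meshgrid is the same callable as paddle.meshgrid, which the import above indicates).

import paddle
from paddle.functional import meshgrid

# Coordinate grids from two 1-D tensors; both outputs have shape [3, 2].
x = paddle.arange(3, dtype='float32')
y = paddle.arange(2, dtype='float32')

grid_x, grid_y = meshgrid(x, y)
print(grid_x.shape, grid_y.shape)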
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -172,6 +172,7 @@
pixel_unshuffle,
)

logsigmoid = log_sigmoid
__all__ = [
'celu',
'conv1d',
@@ -192,6 +193,7 @@
'leaky_relu',
'leaky_relu_',
'log_sigmoid',
'logsigmoid',
'maxout',
'prelu',
'relu',
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -754,6 +754,7 @@ def relu_(x: Tensor, name: str | None = None) -> Tensor:
return _C_ops.relu_(x)


@param_one_alias(["x", "input"])
def log_sigmoid(x: Tensor, name: str | None = None) -> Tensor:
r"""
log_sigmoid activation.
@@ -764,6 +765,7 @@ def log_sigmoid(x: Tensor, name: str | None = None) -> Tensor:

Parameters:
x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
Alias: ``input``.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
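Combined with the logsigmoid alias added in nn/functional/__init__.py above, the decorator lets the torch-style spellings below work; a small sketch, not part of the diff:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-2.0, 0.0, 3.0])

a = F.log_sigmoid(x)       # original spelling, positional `x`
b = F.logsigmoid(input=x)  # alias name plus the new `input` keyword alias

# Both spellings compute log(1 / (1 + exp(-x))) and should match.
assert paddle.allclose(a, b).item()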
37 changes: 37 additions & 0 deletions python/paddle/nn/init.py
@@ -14,6 +14,8 @@

from __future__ import annotations

import numpy as np

import paddle

from ..base.framework import in_dygraph_mode, in_pir_mode
@@ -27,6 +29,41 @@
from .initializer.xavier import XavierNormal, XavierUniform


def _calculate_fan_in_and_fan_out(var: paddle.Tensor) -> tuple[int, int]:
"""Compute the fan_in and the fan_out for layers

This method computes the fan_in and the fan_out
for neural network layers, if not specified. It is
not possible to perfectly estimate fan_in and fan_out.
This method will estimate it correctly for matrix multiply and
convolutions.

Args:
var: variable for which fan_in and fan_out have to be computed.

Returns:
tuple of two integers (fan_in, fan_out).
"""
shape = var.shape
if not shape or len(shape) == 0:
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
# This is the case for simple matrix multiply
fan_in = shape[0]
fan_out = shape[1]
else:
# Assume this to be a convolutional kernel
# In PaddlePaddle, the shape of the kernel is like:
# [num_filters, num_filter_channels, ...] where the remaining
# dimensions are the filter_size
receptive_field_size = np.prod(shape[2:])
fan_in = int(shape[1] * receptive_field_size)
fan_out = int(shape[0] * receptive_field_size)
return (fan_in, fan_out)


def kaiming_uniform_(
tensor: paddle.Tensor,
a: float = 0,
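A worked example of the shape rules above (the expected numbers follow directly from the formula in the diff; calling the private helper this way matches how the new test in test_nn_init_function.py uses it):

import paddle
from paddle.nn.init import _calculate_fan_in_and_fan_out

w_linear = paddle.empty([128, 256])   # 2-D weight: fan_in = 128, fan_out = 256
w_conv = paddle.empty([64, 3, 3, 3])  # conv kernel: receptive field = 3 * 3 = 9

print(_calculate_fan_in_and_fan_out(w_linear))  # (128, 256)
print(_calculate_fan_in_and_fan_out(w_conv))    # (3 * 9, 64 * 9) == (27, 576)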
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -515,6 +515,7 @@
greater = gt
sub = subtract
sub_ = subtract_
clamp_ = clip_

# this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [
@@ -947,6 +948,7 @@
'gt',
'greater',
'clamp',
'clamp_',
]


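A minimal sketch of what this alias provides on Tensor objects (not part of the diff): clamp_ resolves to the same in-place method as clip_.

import paddle

x = paddle.to_tensor([-1.0, 0.2, 0.9])
y = x.clone()

x.clip_(min=-0.5, max=0.5)   # existing in-place clip
y.clamp_(min=-0.5, max=0.5)  # new alias, same underlying method

# Both tensors are clipped to [-0.5, 0.5] in place.
assert paddle.allclose(x, y).item()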
47 changes: 47 additions & 0 deletions test/legacy_test/test_activation_op.py
@@ -854,6 +854,53 @@ def test_errors(self):
F.log_sigmoid(x_fp16)


class TestLogSigmoidOutAndParaDecorator(unittest.TestCase):
def setUp(self) -> None:
paddle.disable_static()
self.apis = [
paddle.nn.functional.log_sigmoid,
paddle.nn.functional.logsigmoid,
]
self.shape = [3, 4, 5]
self.input_np = np.random.random(self.shape).astype('float32')

def do_test(self, api, test_type):
self.test_types = [
"decorator1",
]
x = paddle.to_tensor(self.input_np, stop_gradient=False)
out = paddle.zeros(self.shape, dtype='float32')
out.stop_gradient = False
if test_type == "raw":
out = paddle.nn.functional.log_sigmoid(x)
out.mean().backward()
return out, x.grad
elif test_type == "decorator1":
res = api(input=x)
loss = res.mean()
loss.backward()
x_grad = x.grad
return res, x_grad
else:
raise NotImplementedError(
f"Test type {test_type} is not implemented."
)

def test_api(self):
out_std, x_grad_std = self.do_test(
paddle.nn.functional.log_sigmoid, "raw"
)
for api in self.apis:
for test_type in self.test_types:
out, x_grad = self.do_test(api, test_type)
np.testing.assert_allclose(
out.numpy(), out_std.numpy(), rtol=1e-20
)
np.testing.assert_allclose(
x_grad.numpy(), x_grad_std.numpy(), rtol=1e-20
)


class TestTanh(TestActivation, TestParameter):
def setUp(self):
self.op_type = "tanh"
62 changes: 62 additions & 0 deletions test/legacy_test/test_autocast.py
@@ -0,0 +1,62 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.base import core


@unittest.skipIf(
Contributor comment: Why are so many skipIf conditions written here?

not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(),
"Require compiled with CUDA or XPU.",
)
@unittest.skipIf(
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
@unittest.skipIf(
core.is_compiled_with_xpu()
and core.get_xpu_device_version(0) < core.XPUVersion.XPU3,
"run test when xpu's compute capability >= xpu3.",
)
@unittest.skipIf(
Contributor comment: Aren't these two XPU checks duplicates of each other?

core.is_compiled_with_xpu()
and core.get_xpu_device_version(0) == core.XPUVersion.XPU3,
"Bugs on XPU3, disable temporarily",
)
class TestCudaAutoCast(unittest.TestCase):
def setUp(self):
self._conv = paddle.nn.Conv2D(1, 1, 3, bias_attr=False)
self._linear = paddle.nn.Linear(4, 4)

def _run_autocast_test(self, ctx):
with paddle.autocast(
device_type='cuda',
enabled=True,
dtype=paddle.float16,
cache_enabled=True,
):
out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
out3 = self._linear(out2)

self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float16)
self.assertEqual(out3.dtype, paddle.float32)


if __name__ == '__main__':
unittest.main()
32 changes: 32 additions & 0 deletions test/legacy_test/test_clip_op.py
@@ -1033,5 +1033,37 @@ def test_static_compatibility(self):
np.testing.assert_array_equal(self.np_out, fetches[0])


class TestClamp_AndClip_(unittest.TestCase):
def setUp(self) -> None:
paddle.disable_static()
self.shape = [3, 4, 5]
self.input_np = np.random.random(self.shape).astype('float32')
self.a = np.random.random(self.shape).astype('float32')
self.b = np.random.random(self.shape).astype('float32')
self.min, self.max = -0.5, 0.5

def test_clip_and_clamp(self):
clip_a = paddle.to_tensor(self.a, stop_gradient=False)
clip_b = paddle.to_tensor(self.b, stop_gradient=False)

clamp_a = paddle.to_tensor(self.a, stop_gradient=False)
clamp_b = paddle.to_tensor(self.b, stop_gradient=False)

clip_x = clip_a + clip_b
clip_x.clip_(min=self.min, max=self.max)
clip_x.retain_grads()
clip_x.mean().backward()

clamp_x = clamp_a + clamp_b
clamp_x.clamp_(min=self.min, max=self.max)
clamp_x.retain_grads()
clamp_x.mean().backward()

np.testing.assert_allclose(clip_x.numpy(), clamp_x.numpy(), rtol=1e-20)
np.testing.assert_allclose(
clip_x.grad.numpy(), clamp_x.grad.numpy(), rtol=1e-20
)


if __name__ == '__main__':
unittest.main()
37 changes: 37 additions & 0 deletions test/legacy_test/test_nn_init_function.py
@@ -62,6 +62,22 @@ def _calculate_gain(nonlinearity, param):
return recommended_gain[nonlinearity]


def _calculate_fan_in_and_fan_out(var: paddle.Tensor) -> tuple[int, int]:
shape = var.shape
if not shape or len(shape) == 0:
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
fan_in = shape[0]
fan_out = shape[1]
else:
receptive_field_size = np.prod(shape[2:])
fan_in = shape[1] * receptive_field_size
fan_out = shape[0] * receptive_field_size
return (fan_in, fan_out)


class Test_calculate_gain(unittest.TestCase):
def test(self):
for nonlinearity in [
@@ -87,6 +103,27 @@ def test(self):
)


class TestCAlFanINOUT(unittest.TestCase):
def test_cal_fan_in_and_out(self):
x = paddle.tensor.randn([10])
self.assertEqual(
_calculate_fan_in_and_fan_out(x),
Contributor comment: There is no need to copy the API over again here; the expected values can be computed directly.

paddle.nn.init._calculate_fan_in_and_fan_out(x),
)

y = paddle.tensor.randn([10, 10])
self.assertEqual(
_calculate_fan_in_and_fan_out(y),
paddle.nn.init._calculate_fan_in_and_fan_out(y),
)

z = paddle.randn([10, 10, 10])
self.assertEqual(
_calculate_fan_in_and_fan_out(z),
paddle.nn.init._calculate_fan_in_and_fan_out(z),
)


class Test_kaiming_uniform_(unittest.TestCase):
def check_kaiming_uniform(
self, tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'
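Along the lines of the contributor's comment above, one possible simplification is to compare against literal expected tuples instead of re-implementing the helper; a sketch only (class and method names are illustrative), with expected values derived from the shape rules in paddle/nn/init.py:

import unittest

import paddle


class TestFanInFanOutExpectedValues(unittest.TestCase):
    def test_expected_values(self):
        cases = [
            ([10], (10, 10)),            # 1-D: fan_in == fan_out == shape[0]
            ([10, 10], (10, 10)),        # 2-D: (shape[0], shape[1])
            ([10, 10, 10], (100, 100)),  # conv-like: receptive field size 10
        ]
        for shape, expected in cases:
            t = paddle.randn(shape)
            self.assertEqual(
                paddle.nn.init._calculate_fan_in_and_fan_out(t), expected
            )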