
Commit 95955cc

Merge pull request #1 from reyoung/grad_test_for_multi_inputs
Invoke check_grad many times for no_grad_set
2 parents 4470332 + 3d9d32a commit 95955cc

3 files changed, +37 -29 lines

python/paddle/v2/framework/tests/gradient_checker.py

Lines changed: 3 additions & 20 deletions
@@ -286,7 +286,7 @@ def check_grad(self,
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
-            if name in inputs_to_check:
+            if no_grad in inputs_to_check:
                 raise ValueError("no_grad should not be in inputs_to_check")
 
         backward_op = core.Operator.backward(forward_op, no_grad_set)
@@ -304,25 +304,8 @@ def check_grad(self,
 
         check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            # analytic_grads = self.__get_gradient(forward_op, backward_op,
-            #                                      input_vars, check_names, place)
-            # In fact, the above two lines can be used to replace following
-            # codes. But most of the gradient operators need to handle the case
-            # where one of more of the gradient of the input is not needed.
-            # We change the unit test framework to explicitly test whether
-            # the operator correctly handles this through follow codes.
-            # In addtion, if all the inputs have no gradients, the NOP operator
-            # will be returned by core.Operator.backward(). The following codes
-            # do not test this case.
-            analytic_grads = []
-            for name in inputs_to_check:
-                no_grads = [name for name in no_grad_set]
-                no_grads.extend(filter(lambda x: x != name, inputs_to_check))
-                backward_op = core.Operator.backward(forward_op, set(no_grads))
-                # get analytical gradients according to different device
-                analytic_grads.extend(
-                    self.__get_gradient(forward_op, backward_op, input_vars,
-                                        [grad_var_name(name)], place))
+            analytic_grads = self.__get_gradient(forward_op, backward_op,
+                                                 input_vars, check_names, place)
             self.__assert_is_close(numeric_grads, analytic_grads, check_names,
                                    max_relative_error,
                                    "Gradient Check On %s" % str(place))

python/paddle/v2/framework/tests/test_mul_op.py

Lines changed: 22 additions & 5 deletions
@@ -17,16 +17,33 @@ def setUp(self):
 
 
 class TestMulGradOp(GradientChecker):
-    def test_mul(self):
-        op = create_op("mul")
-        inputs = {
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
-        self.compare_grad(op, inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
 
 
 # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
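For reference, the gradients these new cases exercise follow from the standard matrix-multiply identities: with Out = X @ Y and upstream gradient dOut, dX = dOut @ Y.T and dY = X.T @ dOut. A quick numpy sanity check (illustrative shapes, not taken from the test):

import numpy as np

# Textbook matmul gradients; not code from this repo.
X = np.random.random((4, 3))
Y = np.random.random((3, 5))
d_out = np.ones((4, 5))  # upstream gradient for loss = sum(Out)

d_x = d_out @ Y.T  # gradient w.r.t. X -- what test_ignore_y still checks
d_y = X.T @ d_out  # gradient w.r.t. Y -- what test_ignore_x still checks

# With dOut = ones, every row of d_x equals the row sums of Y.
assert np.allclose(d_x, np.tile(Y.sum(axis=1), (4, 1)))
assert d_y.shape == Y.shape

The TODO above likely refers to exactly these Y.T and X.T products: computing them on-device needs transposed GEMM support from the BLAS library.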

python/paddle/v2/framework/tests/test_rowwise_add_op.py

Lines changed: 12 additions & 4 deletions
@@ -17,13 +17,21 @@ def setUp(self):
 
 
 class RowwiseAddGradOpTest(GradientChecker):
-    def test_rowwise_add(self):
-        op = create_op("rowwise_add")
-        inputs = {
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
             "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
             "b": np.random.uniform(0.1, 1, [10]).astype("float32")
         }
-        self.check_grad(op, inputs, set(["X", "b"]), "Out")
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
 
 
 if __name__ == '__main__':
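The gradients here are equally mechanical: with Out = X + b, where b is broadcast across the rows of X, the textbook results are dX = dOut and db = dOut.sum(axis=0), because a broadcast in the forward pass becomes a sum in the backward pass. A short numpy sketch (illustrative, not repo code):

import numpy as np

# Gradients of Out = X + b with b broadcast over rows; standard identities.
X = np.random.uniform(0.1, 1, [5, 10])
b = np.random.uniform(0.1, 1, [10])
d_out = np.random.random(X.shape)  # arbitrary upstream gradient

d_x = d_out              # what test_ignore_b still checks
d_b = d_out.sum(axis=0)  # what test_ignore_x still checks

assert d_x.shape == X.shape and d_b.shape == b.shape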
