@@ -286,7 +286,7 @@ def check_grad(self,
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
-            if name in inputs_to_check:
+            if no_grad in inputs_to_check:
                 raise ValueError("no_grad should not be in inputs_to_check")
 
         backward_op = core.Operator.backward(forward_op, no_grad_set)
@@ -304,25 +304,8 @@ def check_grad(self,
 
         check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            # analytic_grads = self.__get_gradient(forward_op, backward_op,
-            #                                      input_vars, check_names, place)
-            # In fact, the above two lines can be used to replace following
-            # codes. But most of the gradient operators need to handle the case
-            # where one of more of the gradient of the input is not needed.
-            # We change the unit test framework to explicitly test whether
-            # the operator correctly handles this through follow codes.
-            # In addtion, if all the inputs have no gradients, the NOP operator
-            # will be returned by core.Operator.backward(). The following codes
-            # do not test this case.
-            analytic_grads = []
-            for name in inputs_to_check:
-                no_grads = [name for name in no_grad_set]
-                no_grads.extend(filter(lambda x: x != name, inputs_to_check))
-                backward_op = core.Operator.backward(forward_op, set(no_grads))
-                # get analytical gradients according to different device
-                analytic_grads.extend(
-                    self.__get_gradient(forward_op, backward_op, input_vars,
-                                        [grad_var_name(name)], place))
+            analytic_grads = self.__get_gradient(forward_op, backward_op,
+                                                 input_vars, check_names, place)
             self.__assert_is_close(numeric_grads, analytic_grads, check_names,
                                    max_relative_error,
                                    "Gradient Check On %s" % str(place))
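
For context, this check compares numeric gradients (typically estimated by finite differences through the forward op) against analytic gradients produced by the backward op, within max_relative_error. Below is a minimal standalone sketch of that idea; numeric_gradient, assert_is_close, and the quadratic example are illustrative stand-ins, not the framework's API.

import numpy as np


def numeric_gradient(f, x, delta=1e-5):
    # Central-difference estimate of d f(x) / d x for a scalar-valued f.
    grad = np.zeros_like(x)
    flat_x, flat_grad = x.reshape(-1), grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + delta
        f_pos = f(x)
        flat_x[i] = orig - delta
        f_neg = f(x)
        flat_x[i] = orig
        flat_grad[i] = (f_pos - f_neg) / (2.0 * delta)
    return grad


def assert_is_close(numeric, analytic, max_relative_error=0.005):
    # Same comparison idea as __assert_is_close above: bound the relative error.
    abs_error = np.abs(numeric - analytic)
    relative_error = abs_error / np.maximum(np.abs(numeric), 1e-8)
    assert relative_error.max() <= max_relative_error, (
        "max relative error %e exceeds %e" %
        (relative_error.max(), max_relative_error))


# Usage: d(sum(x ** 2)) / dx should equal 2 * x.
x = np.random.rand(3, 4) + 0.5
assert_is_close(numeric_gradient(lambda v: np.sum(v ** 2), x), 2 * x)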