Commit e82d2ba4 by Ting PAN

Fix the circular reference of GradientTape

Summary:
This commit removes the unused reference of GradientTape to
avoid the circular reference issue.
1 parent 9de0f1a3
#include "dragon/core/workspace.h"
#include "dragon/operators/array/reshape_ops.h"
#include "dragon/utils/math_functions.h"
......@@ -17,6 +18,11 @@ DEPLOY_CPU_OPERATOR(Identity);
DEPLOY_CUDA_OPERATOR(Identity);
#endif
DEPLOY_CPU_OPERATOR(IdentityGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(IdentityGradient);
#endif
OPERATOR_SCHEMA(Identity)
/* X */
.NumInputs(1)
......
......@@ -30,9 +30,10 @@ from dragon.core.util import tls
class Tape(object):
def __init__(self, parent):
"""Tape instance."""
def __init__(self):
self._defs = []
self._parent = parent
self._watched = set()
self._empty_grads = set()
self._gc = workspace.get_workspace().collectors
......@@ -137,6 +138,8 @@ class GradientTape(object):
raise RuntimeError(
'GradientTape.gradient(...) can only be called '
'once on non-persistent tapes.')
# Stop recording if not persistent.
if self._recording:
if not self._persistent:
self._pop_tape()
......@@ -212,7 +215,7 @@ class GradientTape(object):
if self._recording:
raise ValueError('Tape is already recording.')
if self._tape is None:
self._tape = Tape(self)
self._tape = Tape()
push_new_tape(self._tape)
self._recording = True
......
......@@ -660,8 +660,11 @@ class TestArrayOps(OpTestCase):
with execution_context().mode(execution):
data = arange((4,))
x = new_tensor(data)
y = dragon.identity(x)
self.assertEqual(y, data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.identity(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [data, data])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_identity_cuda(self):
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!