Commit e82d2ba4 by Ting PAN

Fix the circular reference of GradientTape

Summary:
This commit removes the unused reference of GradientTape to
avoid the circular reference issue.
1 parent 9de0f1a3
#include "dragon/core/workspace.h"
#include "dragon/operators/array/reshape_ops.h" #include "dragon/operators/array/reshape_ops.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
...@@ -17,6 +18,11 @@ DEPLOY_CPU_OPERATOR(Identity); ...@@ -17,6 +18,11 @@ DEPLOY_CPU_OPERATOR(Identity);
DEPLOY_CUDA_OPERATOR(Identity); DEPLOY_CUDA_OPERATOR(Identity);
#endif #endif
DEPLOY_CPU_OPERATOR(IdentityGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(IdentityGradient);
#endif
OPERATOR_SCHEMA(Identity) OPERATOR_SCHEMA(Identity)
/* X */ /* X */
.NumInputs(1) .NumInputs(1)
......
...@@ -30,9 +30,10 @@ from dragon.core.util import tls ...@@ -30,9 +30,10 @@ from dragon.core.util import tls
class Tape(object): class Tape(object):
def __init__(self, parent): """Tape instance."""
def __init__(self):
self._defs = [] self._defs = []
self._parent = parent
self._watched = set() self._watched = set()
self._empty_grads = set() self._empty_grads = set()
self._gc = workspace.get_workspace().collectors self._gc = workspace.get_workspace().collectors
...@@ -137,6 +138,8 @@ class GradientTape(object): ...@@ -137,6 +138,8 @@ class GradientTape(object):
raise RuntimeError( raise RuntimeError(
'GradientTape.gradient(...) can only be called ' 'GradientTape.gradient(...) can only be called '
'once on non-persistent tapes.') 'once on non-persistent tapes.')
# Stop recording if not persistent.
if self._recording: if self._recording:
if not self._persistent: if not self._persistent:
self._pop_tape() self._pop_tape()
...@@ -212,7 +215,7 @@ class GradientTape(object): ...@@ -212,7 +215,7 @@ class GradientTape(object):
if self._recording: if self._recording:
raise ValueError('Tape is already recording.') raise ValueError('Tape is already recording.')
if self._tape is None: if self._tape is None:
self._tape = Tape(self) self._tape = Tape()
push_new_tape(self._tape) push_new_tape(self._tape)
self._recording = True self._recording = True
......
...@@ -660,8 +660,11 @@ class TestArrayOps(OpTestCase): ...@@ -660,8 +660,11 @@ class TestArrayOps(OpTestCase):
with execution_context().mode(execution): with execution_context().mode(execution):
data = arange((4,)) data = arange((4,))
x = new_tensor(data) x = new_tensor(data)
y = dragon.identity(x) with dragon.GradientTape() as tape:
self.assertEqual(y, data) tape.watch(x)
y = dragon.identity(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [data, data])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_identity_cuda(self): def test_identity_cuda(self):
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!