diff --git a/TensorFlowSharp/Tensorflow.cs b/TensorFlowSharp/Tensorflow.cs index 0f361989..259b39ff 100644 --- a/TensorFlowSharp/Tensorflow.cs +++ b/TensorFlowSharp/Tensorflow.cs @@ -108,6 +108,36 @@ internal static void CheckSize () } } + internal class DefaultGraphStack + { + [ThreadStatic] + private static IList stack; + private static object lockObj = new object(); + + internal static TFGraph Instance => stack.Last(); + + internal static void SetGraph(TFGraph graph) + { + if (stack == null) + { + lock (lockObj) + { + if (stack == null) + { + stack = new List(); + } + } + } + + stack.Add(graph); + } + + internal static void RemoveGraph(TFGraph graph) + { + stack.Remove(graph); + } + } + /// /// Base class for many TensorFlow data types that provides a common idiom to dispose and /// release resources associated with the native data types. Generally, you do not need to use this. @@ -473,12 +503,13 @@ public partial class TFGraph : TFDisposable /// /// Initializes a new instance of the class. /// - public TFGraph () : base (TF_NewGraph ()) + public TFGraph () : this (TF_NewGraph ()) { } internal TFGraph (IntPtr handle) : base (handle) { + DefaultGraphStack.SetGraph(this); } // extern void TF_DeleteGraph (TF_Graph *); @@ -486,6 +517,7 @@ internal TFGraph (IntPtr handle) : base (handle) static extern unsafe void TF_DeleteGraph (TF_Graph graph); internal override void NativeDispose (IntPtr handle) { + DefaultGraphStack.RemoveGraph(this); TF_DeleteGraph (handle); } @@ -3140,6 +3172,50 @@ public override string ToString () { return string.Format ("[{3} Index={1} Operation={2} (0x{0:X})]", (long) LLOperation, Index, Operation, OutputType); } + + /// + /// Plus operation. + /// + /// First output + /// Second output + /// + public static TFOutput operator + (TFOutput o1, TFOutput o2) + { + return DefaultGraphStack.Instance.Add (o1, o2); + } + + /// + /// Multiplication operation. + /// + /// First output + /// Second output + /// + public static TFOutput operator * (TFOutput o1, TFOutput o2) + { + return DefaultGraphStack.Instance.Mul (o1, o2); + } + + /// + /// Division operation. + /// + /// First output + /// Second output + /// + public static TFOutput operator / (TFOutput o1, TFOutput o2) + { + return DefaultGraphStack.Instance.Div (o1, o2); + } + + /// + /// Subtraction operation. + /// + /// First output + /// Second output + /// + public static TFOutput operator - (TFOutput o1, TFOutput o2) + { + return DefaultGraphStack.Instance.Sub (o1, o2); + } } ///