Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DefaultGraphStack #175

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion TensorFlowSharp/Tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,36 @@ internal static void CheckSize ()
}
}

internal class DefaultGraphStack
{
[ThreadStatic]
private static IList<TFGraph> stack;
private static object lockObj = new object();

internal static TFGraph Instance => stack.Last();

internal static void SetGraph(TFGraph graph)
{
if (stack == null)
{
lock (lockObj)
{
if (stack == null)
{
stack = new List<TFGraph>();
}
}
}

stack.Add(graph);
}

internal static void RemoveGraph(TFGraph graph)
{
stack.Remove(graph);
}
}

/// <summary>
/// Base class for many TensorFlow data types that provides a common idiom to dispose and
/// release resources associated with the native data types. Generally, you do not need to use this.
Expand Down Expand Up @@ -473,19 +503,21 @@ public partial class TFGraph : TFDisposable
/// <summary>
/// Initializes a new instance of the <see cref="T:TensorFlow.TFGraph"/> class.
/// </summary>
public TFGraph () : base (TF_NewGraph ())
public TFGraph () : this (TF_NewGraph ())
{
}

internal TFGraph (IntPtr handle) : base (handle)
{
DefaultGraphStack.SetGraph(this);
}

// extern void TF_DeleteGraph (TF_Graph *);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_DeleteGraph (TF_Graph graph);
internal override void NativeDispose (IntPtr handle)
{
DefaultGraphStack.RemoveGraph(this);
TF_DeleteGraph (handle);
}

Expand Down Expand Up @@ -3140,6 +3172,50 @@ public override string ToString ()
{
return string.Format ("[{3} Index={1} Operation={2} (0x{0:X})]", (long) LLOperation, Index, Operation, OutputType);
}

/// <summary>
/// Plus operation.
/// </summary>
/// <param name="o1">First output</param>
/// <param name="o2">Second output</param>
/// <returns></returns>
public static TFOutput operator + (TFOutput o1, TFOutput o2)
{
return DefaultGraphStack.Instance.Add (o1, o2);
}

/// <summary>
/// Multiplication operation.
/// </summary>
/// <param name="o1">First output</param>
/// <param name="o2">Second output</param>
/// <returns></returns>
public static TFOutput operator * (TFOutput o1, TFOutput o2)
{
return DefaultGraphStack.Instance.Mul (o1, o2);
}

/// <summary>
/// Division operation.
/// </summary>
/// <param name="o1">First output</param>
/// <param name="o2">Second output</param>
/// <returns></returns>
public static TFOutput operator / (TFOutput o1, TFOutput o2)
{
return DefaultGraphStack.Instance.Div (o1, o2);
}

/// <summary>
/// Subtraction operation.
/// </summary>
/// <param name="o1">First output</param>
/// <param name="o2">Second output</param>
/// <returns></returns>
public static TFOutput operator - (TFOutput o1, TFOutput o2)
{
return DefaultGraphStack.Instance.Sub (o1, o2);
}
}

/// <summary>
Expand Down