diff --git a/README.md b/README.md
index 5a03d13b..c9fb0520 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,7 @@ Short overview about why anyone would use this, how it came to be (even shorter)
The recommended way to use the latest version of Sigma is adding the NuGet package to your project.
You can either include the core framework (command line only) [![Nuget (PreRelease)](https://img.shields.io/nuget/vpre/Sigma.Core.svg?style=flat-square)](https://www.nuget.org/packages/Sigma.Core) or the WPF visualiser (only works on Windows) which also references the core framework [![Nuget (PreRelease WPF)](https://img.shields.io/nuget/vpre/Sigma.Core.Monitors.WPF.svg?style=flat-square)](https://www.nuget.org/packages/Sigma.Core.Monitors.WPF).
-In both cases, you can use any project with a main (ConsoleApplication) but you have to change the project settings to x64 since **Sigma only supports 64bit mode**.
-
+In both cases, you can use any project with a main method (e.g. a ConsoleApplication), but you have to change the project platform to x64 (since **Sigma only supports 64-bit mode**) and the target framework to **.NET 4.6** before installing the NuGet packages.
### From source
diff --git a/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs
index d480525e..a96a0db9 100644
--- a/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs
@@ -6,6 +6,7 @@ MIT License
For full license see LICENSE in the root directory of this project.
*/
+using System;
using System.Collections.Generic;
using LiveCharts;
using LiveCharts.Wpf;
@@ -42,15 +43,30 @@ public AccuracyPanel(string title, ITrainer trainer, object headerContent = null
/// The content for the header. If null is passed,
/// the title will be used.
///
- public AccuracyPanel(string title, ITrainer trainer, object headerContent = null, params int[] tops) : base(title, headerContent)
+ public AccuracyPanel(string title, ITrainer trainer, object headerContent = null, params int[] tops) : this(title, trainer, TimeStep.Every(1, TimeScale.Epoch), headerContent, tops)
{
+ }
+
+ ///
+ /// Create an AccuracyPanel with a given title. It displays the given accuracies at the given time step.
+ /// If a title is not sufficient modify .
+ ///
+ /// The given title.
+ ///
+ /// The content for the header. If null is passed,
+ /// the title will be used.
+ ///
+ public AccuracyPanel(string title, ITrainer trainer, ITimeStep timeStep, object headerContent = null, params int[] tops) : base(title, headerContent)
+ {
+ if (timeStep == null) throw new ArgumentNullException(nameof(timeStep));
+
// skip the first since its automatically generated
for (int i = 1; i < tops.Length; i++)
{
AddSeries(new LineSeries());
}
- trainer.AddHook(new ChartValidationAccuracyReport(this, "validation", TimeStep.Every(1, TimeScale.Epoch), tops));
+ trainer.AddHook(new ChartValidationAccuracyReport(this, "validation", timeStep, tops));
AxisY.MinValue = 0;
AxisY.MaxValue = 100;
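For reference, a minimal usage sketch of the new overload (the `trainer` instance is assumed to be an already configured `ITrainer`; only types visible in this hunk are used):

```csharp
// Report top-1 and top-3 validation accuracy every 10 iterations instead of the
// previous hard-coded once-per-epoch schedule; headerContent stays null, so the title is reused.
AccuracyPanel accuracyPanel = new AccuracyPanel("Validation accuracy", trainer,
    TimeStep.Every(10, TimeScale.Iteration), null, 1, 3);
```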
diff --git a/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs
index 333be572..7ac2486d 100644
--- a/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs
@@ -48,9 +48,9 @@ namespace Sigma.Core.Monitors.WPF.Panels.Charts
/// The for the hook.
/// The content for the header. If null is passed,
/// the title will be used.
- public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITimeStep timestep, object headerContent = null) : base(title, headerContent)
+ public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITimeStep timestep, bool averageMode = false, object headerContent = null) : base(title, headerContent)
{
- VisualValueReporterHook hook = new VisualValueReporterHook(this, new[] { hookedValue }, timestep);
+ VisualValueReporterHook hook = new VisualValueReporterHook(this, new[] { hookedValue }, timestep, averageMode);
Init(trainer, hook);
}
@@ -65,9 +65,9 @@ public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITi
/// The for the hook.
/// The content for the header. If null is passed,
/// the title will be used.
- public TrainerChartPanel(string title, ITrainer trainer, ITimeStep timestep, object headerContent = null, params string[] hookedValues) : base(title, headerContent)
+ public TrainerChartPanel(string title, ITrainer trainer, ITimeStep timestep, bool averageMode = false, object headerContent = null, params string[] hookedValues) : base(title, headerContent)
{
- VisualValueReporterHook hook = new VisualValueReporterHook(this, hookedValues, timestep);
+ VisualValueReporterHook hook = new VisualValueReporterHook(this, hookedValues, timestep, averageMode);
Init(trainer, hook);
}
@@ -118,7 +118,7 @@ protected class VisualValueReporterHook : ValueReporterHook
/// The chartpanel to which points will get added.
/// The identifiers for the ; these values will get plotted.
/// The for the hook (i.e. execution definition).
- public VisualValueReporterHook(ChartPanel chartPanel, string[] valueIdentifiers, ITimeStep timestep) : base(valueIdentifiers, timestep)
+ public VisualValueReporterHook(ChartPanel chartPanel, string[] valueIdentifiers, ITimeStep timestep, bool averageMode = false) : base(valueIdentifiers, timestep, averageMode, false)
{
ParameterRegistry[ChartPanelIdentifier] = chartPanel;
}
@@ -127,7 +127,10 @@ public VisualValueReporterHook(ChartPanel
/// Report the values for a certain epoch / iteration to a passed ChartPanel.
///
/// The values by their identifier.
- protected override void ReportValues(IDictionary valuesByIdentifier)
+ /// A boolean indicating whether or not to report the current epoch / iteration.
+ /// The current epoch.
+ /// The current iteration.
+ protected override void ReportValues(IDictionary valuesByIdentifier, bool reportEpochIteration, int epoch, int iteration)
{
ChartPanel chartPanel = (ChartPanel) ParameterRegistry[ChartPanelIdentifier];
chartPanel.Add((TData) valuesByIdentifier.Values.First());
diff --git a/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs
index ca4f677d..f5790190 100644
--- a/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs
@@ -6,10 +6,17 @@ MIT License
For full license see LICENSE in the root directory of this project.
*/
-using System.Windows;
-using System.Windows.Controls;
using Sigma.Core.Monitors.WPF.View.CustomControls.Panels.Control;
+using Sigma.Core.Monitors.WPF.View.Parameterisation;
+using Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults;
+using Sigma.Core.Monitors.WPF.View.Windows;
using Sigma.Core.Training;
+using Sigma.Core.Training.Hooks.Reporters;
+using Sigma.Core.Utils;
+using System;
+using System.Collections.Generic;
+using System.Windows;
+using System.Windows.Controls;
namespace Sigma.Core.Monitors.WPF.Panels.Controls
{
@@ -19,7 +26,8 @@ namespace Sigma.Core.Monitors.WPF.Panels.Controls
///
public class ControlPanel : GenericPanel
{
- private readonly SigmaPlaybackControl _playbackControl;
+ private SigmaPlaybackControl _playbackControl;
+ private ParameterView _parameterView;
private ITrainer _trainer;
@@ -36,6 +44,18 @@ public ITrainer Trainer
}
}
+ ///
+ /// This list stores all trainers that have been initialised.
+ /// Required to only add one hook per trainer.
+ ///
+ private static readonly IList Trainers;
+
+ static ControlPanel()
+ {
+ Trainers = new List();
+ }
+
+
public ControlPanel(string title, object content = null) : this(title, null, content) { }
public ControlPanel(string title, ITrainer trainer, object content = null) : base(title, content)
@@ -49,10 +69,61 @@ public ControlPanel(string title, ITrainer trainer, object content = null) : bas
HorizontalAlignment = HorizontalAlignment.Center,
Margin = new Thickness(0, 20, 0, 0)
};
+ }
+
+ ///
+ /// This method will be called while the window is initialising (after the panel has been added).
+ /// Do not store a reference to the window unless you dispose of it properly (remove the reference once it is no longer required).
+ ///
+ /// The wpf window this panel will be added to.
+ protected override void OnInitialise(WPFWindow window)
+ {
+ throw new InvalidOperationException($"{nameof(ControlPanel)} is only compatible with {nameof(SigmaWindow)}s.");
+ }
- _playbackControl = new SigmaPlaybackControl { Trainer = Trainer };
+ ///
+ /// This method will be called after the panel has been added (window, monitor set...)
+ ///
+ protected override void OnInitialise(SigmaWindow window)
+ {
+ if (!Trainers.Contains(Trainer))
+ {
+ ValueSourceReporterHook valueHook = new ValueSourceReporterHook(TimeStep.Every(1, TimeScale.Epoch), "runtime_millis");
+ _trainer.AddGlobalHook(valueHook);
+ Monitor.Sigma.SynchronisationHandler.AddSynchronisationSource(valueHook);
+ Trainers.Add(Trainer);
+
+ valueHook = new ValueSourceReporterHook(TimeStep.Every(1, TimeScale.Iteration), "iteration");
+ _trainer.AddLocalHook(valueHook);
+ Monitor.Sigma.SynchronisationHandler.AddSynchronisationSource(valueHook);
+ }
+
+ //TODO: style?
+ _playbackControl = new SigmaPlaybackControl { Trainer = Trainer, Margin = new Thickness(0, 0, 0, 20), HorizontalAlignment = HorizontalAlignment.Center};
Content.Children.Add(_playbackControl);
+
+ _parameterView = new ParameterView(Monitor.Sigma, window);
+
+ //TODO: language support
+
+ SigmaTextBlock timeBox = (SigmaTextBlock) _parameterView.Add("Running time", typeof(object), _trainer.Operator.Registry, "runtime_millis");
+ timeBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Epoch));
+ timeBox.Postfix = " ms";
+
+ UserControlParameterVisualiser epochBox = (UserControlParameterVisualiser) _parameterView.Add("Current epoch", typeof(object), _trainer.Operator.Registry, "epoch");
+ epochBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Epoch));
+
+ UserControlParameterVisualiser iterationBox = (UserControlParameterVisualiser) _parameterView.Add("Current iteration", typeof(object), _trainer.Operator.Registry, "iteration");
+ iterationBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Iteration));
+
+ IRegistry registry = new Registry
+ {
+ { "op", Trainer.Operator.GetType().Name }
+ };
+ _parameterView.Add("Current operator", typeof(object), registry, "op");
+
+ Content.Children.Add(_parameterView);
}
}
}
diff --git a/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs
index c256afdc..ffa2cd40 100644
--- a/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs
@@ -108,7 +108,7 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
});
DataProviderUtils.ProvideExternalInputData(provider, network, block);
- network.Run(Operator.Handler, false);
+ network.Run(Operator.Handler, trainingPass: false);
DataProviderUtils.ProvideExternalOutputData(provider, network, block);
}
}
diff --git a/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs
index edf1a8ef..ea956598 100644
--- a/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs
@@ -9,6 +9,7 @@ For full license see LICENSE in the root directory of this project.
using System;
using Sigma.Core.Monitors.Synchronisation;
using Sigma.Core.Monitors.WPF.View.Parameterisation;
+using Sigma.Core.Monitors.WPF.View.Windows;
using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
using Sigma.Core.Utils;
@@ -27,6 +28,7 @@ public ParameterPanel(string title, IParameterVisualiserManager visualiserManage
{
Content = new ParameterView(visualiserManager, synchronisationHandler);
}
+ public ParameterPanel(string title, SigmaEnvironment environment, SigmaWindow window, object headerContent = null) : this(title, window.ParameterVisualiser, environment.SynchronisationHandler, headerContent) { }
public void Add(string name, Type type, IRegistry registry, string key)
{
diff --git a/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs b/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs
index 73028d60..5fe123e4 100644
--- a/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs
+++ b/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs
@@ -9,6 +9,7 @@ For full license see LICENSE in the root directory of this project.
using System.Windows;
using System.Windows.Controls;
using MaterialDesignThemes.Wpf;
+using Sigma.Core.Monitors.WPF.View.Windows;
// ReSharper disable VirtualMemberCallInConstructor
@@ -38,6 +39,11 @@ public abstract class SigmaPanel : Card
///
private UIElement _content;
+ ///
+ /// The currently responsible monitor - it will be set automatically when adding a new panel (null until ).
+ ///
+ public WPFMonitor Monitor { get; set; }
+
///
/// Create a SigmaPanel with a given title.
/// If a title is not sufficient modify .
@@ -99,7 +105,7 @@ protected SigmaPanel(string title, object content = null)
}
else
{
- _content = value as UIElement ?? new Label { Content = value.ToString() };
+ _content = value as UIElement ?? new Label {Content = value.ToString()};
ContentGrid.Children.Add(_content);
}
}
@@ -143,10 +149,10 @@ protected virtual Grid CreateHeader(object content)
{
Grid header = new Grid();
- header.RowDefinitions.Add(new RowDefinition { Height = new GridLength(1, GridUnitType.Auto) });
- header.ColumnDefinitions.Add(new ColumnDefinition { Width = new GridLength(1, GridUnitType.Auto) });
+ header.RowDefinitions.Add(new RowDefinition {Height = new GridLength(1, GridUnitType.Auto)});
+ header.ColumnDefinitions.Add(new ColumnDefinition {Width = new GridLength(1, GridUnitType.Auto)});
- Label headerContent = new Label { Content = content };
+ Label headerContent = new Label {Content = content};
header.Children.Add(headerContent);
header.SetResourceReference(BackgroundProperty, "SigmaPanelHeaderBackground");
@@ -174,19 +180,54 @@ protected virtual Grid CreateContentGrid()
{
Grid grid = new Grid();
- grid.RowDefinitions.Add(new RowDefinition { Height = new GridLength(1, GridUnitType.Star) });
- grid.ColumnDefinitions.Add(new ColumnDefinition { Width = new GridLength(1, GridUnitType.Star) });
+ grid.RowDefinitions.Add(new RowDefinition {Height = new GridLength(1, GridUnitType.Star)});
+ grid.ColumnDefinitions.Add(new ColumnDefinition {Width = new GridLength(1, GridUnitType.Star)});
return grid;
}
+ ///
+ /// This method invokes the initialisation of the panel (after it has been added).
+ ///
+ public void Initialise(WPFWindow window)
+ {
+ if (window is SigmaWindow)
+ {
+ OnInitialise((SigmaWindow)window);
+ }
+ else
+ {
+ OnInitialise(window);
+ }
+ }
+
+ ///
+ /// This method will be called while the window is initialising (after the panel has been added).
+ /// Do not store a reference to the window unless you dispose of it properly (remove the reference once it is no longer required).
+ ///
+ /// The wpf window this panel will be added to.
+ protected virtual void OnInitialise(WPFWindow window)
+ {
+
+ }
+
+ ///
+ /// This method will be called while the window is initialising (after the panel has been added).
+ /// Do not store a reference to the window unless you dispose of it properly (remove the reference once it is no longer required).
+ ///
+ /// The wpf window this panel will be added to.
+ protected virtual void OnInitialise(SigmaWindow window)
+ {
+
+ }
+
///
/// Create the default panel in which every other element is contained.
///
+ /// The newly created .
protected virtual DockPanel CreateDockPanel()
{
- return new DockPanel { LastChildFill = true, Margin = new Thickness(-1, 0, 0, 0) };
+ return new DockPanel {LastChildFill = true, Margin = new Thickness(-1, 0, 0, 0)};
}
}
}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj b/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj
index 13638d38..8010611e 100644
--- a/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj
+++ b/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj
@@ -168,6 +168,7 @@
+
@@ -222,6 +223,8 @@
SigmaComboBox.xaml
+
+
SigmaSlider.xaml
@@ -233,6 +236,7 @@
+
ParameterView.xaml
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs
new file mode 100644
index 00000000..7e11bf59
--- /dev/null
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs
@@ -0,0 +1,59 @@
+using System;
+using System.ComponentModel;
+
+namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults
+{
+ ///
+ /// This is a generic text box that automatically uses the type of the registry value in order to correctly display it and parse it back.
+ ///
+ /// It is necessary for XAML or other cases where generics are not possible.
+ ///
+ public class SigmaDynamicGenericBox : SigmaTextBox
+ {
+ ///
+ /// The converter that converts the given type for the registry
+ ///
+ public TypeConverter Converter { get; protected set; }
+
+ ///
+ /// The current value that is displayed
+ ///
+ public object CurrentValue { get; protected set; }
+
+ ///
+ /// Force the visualiser to update its value (i.e. display the value that is stored).
+ ///
+ public override void Read()
+ {
+ object obj = SynchronisationHandler.SynchroniseGet(Registry, Key);
+
+ if (Converter == null && obj != null)
+ {
+ Converter = TypeDescriptor.GetConverter(obj.GetType());
+ }
+
+ if (obj != null)
+ {
+ CurrentValue = obj;
+ Text = obj.ToString();
+ }
+ }
+
+ ///
+ /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
+ ///
+ public override void Write()
+ {
+ try
+ {
+ object convertedVal = Converter.ConvertFromString(Text);
+ Pending = true;
+ SynchronisationHandler.SynchroniseSet(Registry, Key, convertedVal, val => Pending = false, e => Errored = true);
+ }
+ catch (Exception)
+ {
+ Errored = true;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs
new file mode 100644
index 00000000..652d696b
--- /dev/null
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs
@@ -0,0 +1,71 @@
+using System;
+using System.ComponentModel;
+using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
+
+namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults
+{
+ ///
+ /// A textbox that can convert from and to an arbitrary object with a type converter.
+ ///
+ /// The type that is currently being represented.
+ [GenericParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)]
+ [GenericParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)]
+ [GenericParameterVisualiser(typeof(short), Priority = VisualiserPriority.Lower)]
+ [GenericParameterVisualiser(typeof(int), Priority = VisualiserPriority.Lower)]
+ [GenericParameterVisualiser(typeof(long), Priority = VisualiserPriority.Lower)]
+ public class SigmaGenericBox : SigmaTextBox
+ {
+ ///
+ /// The converter that converts the given type for the registry
+ ///
+ public TypeConverter Converter { get; protected set; }
+
+ ///
+ /// The current active value.
+ ///
+ protected T CurrentValue;
+
+ ///
+ /// Create a generic box and initialise the converter.
+ ///
+ public SigmaGenericBox()
+ {
+ Converter = TypeDescriptor.GetConverter(typeof(T));
+ }
+
+ ///
+ /// Force the visualiser to update its value (i.e. display the value that is stored).
+ ///
+ public override void Read()
+ {
+ SynchronisationHandler.SynchroniseUpdate(Registry, Key, CurrentValue, val =>
+ {
+ CurrentValue = val;
+ Text = CurrentValue.ToString();
+ });
+ }
+
+ ///
+ /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
+ ///
+ public override void Write()
+ {
+ try
+ {
+ T convertedValue = (T) Converter.ConvertFromString(Text);
+ Pending = true;
+ SynchronisationHandler.SynchroniseSet(Registry, Key, convertedValue, val =>
+ {
+ Pending = false;
+ Errored = false;
+ }, e => Errored = true);
+
+ }
+ catch (Exception)
+ {
+ Errored = true;
+ throw;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml
index e2dbc879..15d9a139 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml
@@ -26,7 +26,7 @@
IsEnabled="{Binding IsReadOnly, Converter={StaticResource InverseBooleanConverter}}"
VerticalAlignment="Center"/>
-
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs
index fffd80fd..a650317a 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs
@@ -6,6 +6,8 @@ MIT License
For full license see LICENSE in the root directory of this project.
*/
+using System.Windows;
+using System.Windows.Controls;
using Sigma.Core.Monitors.Synchronisation;
using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
using Sigma.Core.Utils;
@@ -18,18 +20,22 @@ namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults
[ParameterVisualiser(typeof(object), Priority = VisualiserPriority.Lower)]
public partial class SigmaTextBlock
{
- private object _object;
+ ///
+ /// The object that is currently being displayed (without updating the displayed information)
+ ///
+ protected object _Object;
///
/// The object that is being displayed (toString is called).
///
- public object Object
+ public virtual object Object
{
- get { return _object; }
+ get { return _Object; }
set
{
- _object = value;
- TextBlock.Text = value?.ToString() ?? "null";
+ _Object = value;
+ string text = value?.ToString() ?? "null";
+ TextBlock.Text = Prefix + text + Postfix;
}
}
@@ -38,6 +44,29 @@ public object Object
///
public string Text => TextBlock.Text;
+ ///
+ /// This string will be added before the displayed string.
+ ///
+ public string Prefix
+ {
+ get { return (string) GetValue(PrefixProperty); }
+ set { SetValue(PrefixProperty, value); }
+ }
+
+ public static readonly DependencyProperty PrefixProperty =
+ DependencyProperty.Register("Prefix", typeof(string), typeof(SigmaTextBox), new PropertyMetadata(""));
+
+ ///
+ /// This string will be added after the displayed string.
+ ///
+ public string Postfix
+ {
+ get { return (string) GetValue(PostfixProperty); }
+ set { SetValue(PostfixProperty, value); }
+ }
+
+ public static readonly DependencyProperty PostfixProperty =
+ DependencyProperty.Register("Postfix", typeof(string), typeof(SigmaTextBox), new PropertyMetadata(""));
///
/// The fully resolved key to access the synchandler.
@@ -86,7 +115,7 @@ public SigmaTextBlock()
///
public override void Read()
{
- Object = SynchronisationHandler.SynchroniseGet(Registry, Key);
+ SynchronisationHandler.SynchroniseUpdate(Registry, Key, Object, newObj => Dispatcher.Invoke(() => Object = newObj));
}
///
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs
index a1cc1bc2..9dc7c7a4 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs
@@ -15,110 +15,6 @@ For full license see LICENSE in the root directory of this project.
namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults
{
- //[ParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)]
- //internal class SigmaFloatBox : DynamicConverterBox
- //{ }
-
- //[ParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)]
- //internal class SigmaDoubleBox : DynamicConverterBox
- //{ }
-
- /////
- ///// Sigmas way of converting a string to given value (e.g. double)
- /////
- //public class SigmaConverterBox : SigmaTextBox
- //{
- // ///
- // /// The converter that converts the given type for the registry
- // ///
- // protected readonly TypeConverter Converter;
-
- // public SigmaConverterBox()
- // {
- // Converter = TypeDescriptor.GetConverter(typeof(T));
- // }
-
- // ///
- // /// Force the visualiser to update its value (i.e. display the value that is stored).
- // ///
- // public override void Read()
- // {
- // Text = SynchronisationHandler.SynchroniseGet(Registry, Key).ToString();
- // }
-
- // ///
- // /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
- // ///
- // public override void Write()
- // {
- // try
- // {
- // T num = (T) Converter.ConvertFromString(Text);
- // Pending = true;
- // SynchronisationHandler.SynchroniseSet(Registry, Key, num, val => Pending = false, e => Errored = true);
- // }
- // catch (Exception)
- // {
- // Errored = true;
- // }
- // }
- //}
- [ParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)]
- [ParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)]
- [ParameterVisualiser(typeof(long), Priority = VisualiserPriority.Lower)]
- [ParameterVisualiser(typeof(int), Priority = VisualiserPriority.Lower)]
- [ParameterVisualiser(typeof(short), Priority = VisualiserPriority.Lower)]
- public class DynamicConverterBox : SigmaTextBox
- {
- ///
- /// The converter that converts the given type for the registry
- ///
- public TypeConverter Converter { get; protected set; }
-
- public object CurrentValue { get; protected set; }
-
- ///
- /// Force the visualiser to update its value (i.e. display the value that is stored).
- ///
- public override void Read()
- {
- object obj = SynchronisationHandler.SynchroniseGet(Registry, Key);
-
- if (Converter == null && obj != null)
- {
- Converter = TypeDescriptor.GetConverter(obj.GetType());
- }
-
- if (obj != null)
- {
- CurrentValue = obj;
- Text = obj.ToString();
- }
- }
-
- ///
- /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
- ///
- public override void Write()
- {
- try
- {
- object convertedVal = Converter.ConvertFromString(Text);
- Pending = true;
- SynchronisationHandler.SynchroniseSet(Registry, Key, convertedVal, val => Pending = false, e => Errored = true);
- }
- catch (Exception)
- {
- Errored = true;
-
-#if DEBUG
- //TODO: such an ugly hack only for testing
- Read();
-#endif
- }
- }
- }
-
///
/// Sigmas way of displaying strings.
///
@@ -151,7 +47,7 @@ public partial class SigmaTextBox
public string Text
{
get { return TextBox.Text; }
- set { TextBox.Text = value; }
+ set { Dispatcher.Invoke(() => TextBox.Text = value); }
}
///
@@ -181,7 +77,7 @@ public SigmaTextBox()
///
public override void Read()
{
- Text = SynchronisationHandler.SynchroniseGet(Registry, Key);
+ SynchronisationHandler.SynchroniseUpdate(Registry, Key, Text, newVal => Text = newVal);
}
///
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs
new file mode 100644
index 00000000..b2ea855b
--- /dev/null
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs
@@ -0,0 +1,60 @@
+using System;
+using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
+
+namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults
+{
+ //TODO: editable time box (timepicker)
+ //[ParameterVisualiser(typeof(DateTime), Priority = VisualiserPriority.Lower)]
+ //public class SigmaTimeBox : SigmaTextBox
+ //{
+
+ //}
+
+ ///
+ /// A TimeBlock that allows to display the current time.
+ ///
+ [ParameterVisualiser(typeof(DateTime), Priority = VisualiserPriority.Lower)]
+ public class SigmaTimeBlock : SigmaTextBlock
+ {
+ ///
+ public override object Object
+ {
+ get { return _Object; }
+ set
+ {
+ if (_Object is DateTime)
+ {
+ _Object = value;
+ TextBlock.Text = ((DateTime) Object).ToString(FormatString);
+ }
+ else
+ {
+ base.Object = value;
+ }
+ }
+ }
+
+ ///
+ /// The string that is used to format the time.
+ /// null , if default formatting should be applied.
+ ///
+ public string FormatString { get; set; }
+
+ ///
+ /// Create a label that is capable of displaying a time.
+ ///
+ public SigmaTimeBlock()
+ {
+ base.IsReadOnly = true;
+ }
+
+ ///
+ /// Determines whether the parameter is editable or not.
+ ///
+ public sealed override bool IsReadOnly
+ {
+ get { return base.IsReadOnly; }
+ set { throw new NotImplementedException(); }
+ }
+ }
+}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs
index a04a13df..d00adbcf 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs
@@ -8,10 +8,11 @@ For full license see LICENSE in the root directory of this project.
using System;
using System.Windows;
-using System.Windows.Annotations;
using System.Windows.Controls;
using log4net;
using Sigma.Core.Monitors.Synchronisation;
+using Sigma.Core.Monitors.WPF.Annotations;
+using Sigma.Core.Monitors.WPF.View.Windows;
using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
using Sigma.Core.Utils;
@@ -24,11 +25,26 @@ public partial class ParameterView
{
private readonly ILog _log = LogManager.GetLogger(typeof(ParameterView));
+ ///
+ /// The manager that keeps track of all visualisers.
+ ///
protected readonly IParameterVisualiserManager Manager;
+
+ ///
+ /// The handler that is used for transactions with the monitor.
+ ///
protected readonly ISynchronisationHandler SynchronisationHandler;
+ ///
+ /// The current row position to which elements will be added.
+ ///
protected int RowPos;
+ ///
+ /// Generate a new parameter view with a given manager and synchronisation handler.
+ ///
+ /// A manager of all active visualisers.
+ /// A handler for parameter syncing.
public ParameterView(IParameterVisualiserManager manager, ISynchronisationHandler synchronisationHandler)
{
Manager = manager;
@@ -36,14 +52,31 @@ public ParameterView(IParameterVisualiserManager manager, ISynchronisationHandle
InitializeComponent();
}
- public void Add(string name, Type type, IRegistry registry, string key)
+ ///
+ /// Generate a new parameter view with the manager and synchronisation handler assigned in the environment / window.
+ /// The currently active environment.
+ /// The currently active window (i.e. root window).
+ ///
+ public ParameterView(SigmaEnvironment environment, SigmaWindow window) : this(window.ParameterVisualiser, environment.SynchronisationHandler)
+ {
+ }
+
+ ///
+ /// Display a given type stored in a given registry (with the given key) next to a label with the given text.
+ ///
+ /// The text the label will contain.
+ /// The type that will be displayed.
+ /// The registry which contains the value that should be displayed.
+ /// The key to access the exact value required.
+ public IParameterVisualiser Add(string name, Type type, IRegistry registry, string key)
{
- Add(new Label { Content = name }, type, registry, key);
+ return Add(new Label {Content = name}, type, registry, key);
}
- public void Add(UIElement name, Type visualiserType, IRegistry registry, string key)
+ public IParameterVisualiser Add(UIElement name, Type visualiserType, IRegistry registry, string key)
{
- UIElement displayer = (UIElement) Activator.CreateInstance(Manager.VisualiserType(visualiserType));
+ //UIElement displayer = (UIElement) Activator.CreateInstance(Manager.VisualiserType(visualiserType));
+ UIElement displayer = (UIElement) Manager.InstantiateVisualiser(visualiserType);
IParameterVisualiser visualiser = displayer as IParameterVisualiser;
if (visualiser == null)
@@ -52,11 +85,13 @@ public void Add(UIElement name, Type visualiserType, IRegistry registry, string
}
Add(name, displayer, visualiser, registry, key);
+
+ return visualiser;
}
public void Add(string name, object visualiserAndDisplayer, IRegistry registry, string key)
{
- Add(new Label { Content = name }, visualiserAndDisplayer, registry, key);
+ Add(new Label {Content = name}, visualiserAndDisplayer, registry, key);
}
public void Add(UIElement name, object visualiserAndDisplayer, IRegistry registry, string key)
@@ -74,16 +109,16 @@ public void Add(UIElement name, object visualiserAndDisplayer, IRegistry registr
}
///
- ///
+ /// Add a that contains information (e.g. the name of the object), and display it with a given object (e.g. the object to interact with).
///
- ///
+ /// The element that displays information about the element being displayed (e.g. descriptive name).
/// The element that displays the object in the cell (normally the same as ).
/// The object that is responsible for the link with a variable (normally the same as ).
- ///
- ///
- public void Add(UIElement name, UIElement displayer, IParameterVisualiser visualiser, IRegistry registry, string key)
+ /// The registry which contains the value that should be displayed. May or may not be null (depending on the visualiser).
+ /// The key to access the exact value required. May or may not be null (depending on the visualiser).
+ public void Add([CanBeNull] UIElement name, [CanBeNull] UIElement displayer, [CanBeNull] IParameterVisualiser visualiser, IRegistry registry, string key)
{
- Content.RowDefinitions.Add(new RowDefinition { Height = GridLength.Auto });
+ Content.RowDefinitions.Add(new RowDefinition {Height = GridLength.Auto});
if (visualiser != null)
{
@@ -93,6 +128,7 @@ public void Add(UIElement name, UIElement displayer, IParameterVisualiser visual
visualiser.Read();
}
+ // add the name to the left
if (name != null)
{
Grid.SetColumn(name, 0);
@@ -109,5 +145,8 @@ public void Add(UIElement name, UIElement displayer, IParameterVisualiser visual
RowPos++;
}
+
+
+
}
}
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs
index 5ab287e4..447caf94 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs
@@ -33,10 +33,33 @@ public class ParameterVisualiserAttribute : Attribute, IParameterVisualiserInfo
///
/// Define that the class visualises given type.
///
- ///
+ /// The type that is being represented.
public ParameterVisualiserAttribute(Type type)
{
Type = type;
}
+
+ ///
+ /// Determines whether the given visualiser is generic or not.
+ ///
+ public bool IsGeneric { get; protected set; }
+ }
+
+ ///
+ /// This marks an . It contains information about which type this visualiser implements
+ /// and reduces the amount of work required to define a new type. Unlike , the class carrying this
+ /// attribute has to be a generic class with a single type parameter, which will be set to the given type.
+ /// Multiple attributes can be specified (to display strings and objects, for example).
+ ///
+ public class GenericParameterVisualiserAttribute : ParameterVisualiserAttribute
+ {
+ ///
+ /// Define that the class visualises given type.
+ ///
+ /// The type that is being represented.
+ public GenericParameterVisualiserAttribute(Type type) : base(type)
+ {
+ IsGeneric = true;
+ }
}
}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs
index bcbccffc..fdfbb7e5 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs
@@ -18,25 +18,24 @@ namespace Sigma.Core.Monitors.WPF.View.Parameterisation
///
public class ParameterVisualiserInfo : IParameterVisualiserInfo
{
- ///
- /// The type this visualiser visualises.
- ///
+ ///
public Type Type { get; }
- ///
- /// The priority of the . If another priority with a lower priority has already been added, the
- /// higher priority will override the settings.
- ///
+ ///
public VisualiserPriority Priority { get; set; }
+ ///
+ public bool IsGeneric { get; }
+
///
/// Initializes a new instance of the class.
///
/// The type the is responsible for.
/// The priority of the info. (higher prioriuty overrides lower ones).
+ /// Determines whether the given visualiser is generic or not.
/// If bad enum is passed.
/// If is null .
- public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority = VisualiserPriority.Normal)
+ public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority = VisualiserPriority.Normal, bool isGeneric = false)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (!Enum.IsDefined(typeof(VisualiserPriority), priority))
@@ -46,6 +45,9 @@ public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority
Type = type;
Priority = priority;
+ IsGeneric = isGeneric;
}
+
+
}
}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs
index fef2ae29..4b5f2627 100644
--- a/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs
+++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs
@@ -6,9 +6,12 @@ MIT License
For full license see LICENSE in the root directory of this project.
*/
+using System;
using System.Windows.Controls;
using Sigma.Core.Monitors.Synchronisation;
using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation;
+using Sigma.Core.Training;
+using Sigma.Core.Training.Hooks;
using Sigma.Core.Utils;
namespace Sigma.Core.Monitors.WPF.View.Parameterisation
@@ -62,5 +65,28 @@ public abstract class UserControlParameterVisualiser : UserControl, IParameterVi
/// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
///
public abstract void Write();
+
+ ///
+ /// The currently active poll hook that is responsible for updating values.
+ ///
+ protected PollParameterHook ActiveHook;
+
+ ///
+ /// Enables the automatic polling of values (call Read on every TimeStep).
+ /// null if no automatic polling should be enabled.
+ ///
+ /// The trainer on which the poll will be performed.
+ /// The TimeStep on when the parameter should update.
+ public virtual void AutoPollValues(ITrainer trainer, ITimeStep step)
+ {
+ if (ActiveHook != null)
+ {
+ trainer.Operator.DetachGlobalHook(ActiveHook);
+ }
+
+ ActiveHook = new PollParameterHook(step, this);
+
+ trainer.AddGlobalHook(ActiveHook);
+ }
}
}
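A short usage sketch of `AutoPollValues`, mirroring the wiring in `ControlPanel.OnInitialise` above (the `parameterView` and `trainer` instances are assumed to exist already; the cast only succeeds if the resolved visualiser derives from `UserControlParameterVisualiser`):

```csharp
// Display the trainer's current epoch and re-read (poll) it once per epoch.
// Calling AutoPollValues again would first detach the previously attached PollParameterHook.
UserControlParameterVisualiser epochBox = (UserControlParameterVisualiser)
    parameterView.Add("Current epoch", typeof(object), trainer.Operator.Registry, "epoch");
epochBox.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch));
```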
diff --git a/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs b/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs
index 1e615dee..2428a1ba 100644
--- a/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs
+++ b/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs
@@ -171,7 +171,7 @@ public class SigmaWindow : WPFWindow, IDisposable
/// The that is responsible for creation and detection of
/// visualisation elements.
///
- public IParameterVisualiserManager ParameterVisualiser { get; }
+ public IParameterVisualiserManager ParameterVisualiser { get; }
///
/// The prefix-identifier for .
@@ -663,7 +663,7 @@ protected virtual void AddTabs(TabControlUI tabControl, List
{
foreach (string name in names)
{
- tabControl.AddTab(name, new TabUI(name, DefaultGridSize));
+ tabControl.AddTab(name, new TabUI(Monitor, name, DefaultGridSize));
}
}
@@ -675,7 +675,7 @@ public virtual void AddTabs(params string[] tabs)
{
foreach (string tab in tabs)
{
- TabControl.AddTab(tab, new TabUI(tab, DefaultGridSize));
+ TabControl.AddTab(tab, new TabUI(Monitor, tab, DefaultGridSize));
}
}
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs
index c1b230bd..1f80063f 100644
--- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs
@@ -48,11 +48,13 @@ public interface IParameterVisualiser
///
/// Force the visualiser to update its value (i.e. display the value that is stored).
+ /// This function may be called from an arbitrary thread.
///
void Read();
///
/// Force the visualiser to store its value (i.e. write the value that is displayed to the registry).
+ /// This function may be called from an arbitrary thread.
///
void Write();
}
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs
index 9e82d5e6..5772bd64 100644
--- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs
@@ -49,6 +49,13 @@ public interface IParameterVisualiserInfo
///
Type Type { get; }
+ ///
+ /// Determines whether the given visualiser is generic or not.
+ ///
+ /// More specifically, whether the visualisation type should be passed to the visualiser as a generic type argument or not.
+ ///
+ bool IsGeneric { get; }
+
///
/// The priority of the . If another priority with a lower priority has already been added, the
/// higher priority will override the settings.
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs
index 9dd86e73..9ed90060 100644
--- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs
@@ -8,7 +8,6 @@ For full license see LICENSE in the root directory of this project.
using System;
using Sigma.Core.Monitors.WPF.Annotations;
-using Sigma.Core.Monitors.WPF.View.Parameterisation;
using Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults;
namespace Sigma.Core.Monitors.WPF.ViewModel.Parameterisation
@@ -50,5 +49,11 @@ public interface IParameterVisualiserManager
/// The closest type for visualisation. null if not found.
Type VisualiserTypeByReference([NotNull] object obj);
+ ///
+ /// Instantiate a visualiser that can represent given type.
+ ///
+ /// The type that will be visualised.
+ /// An instance of a visualiser.
+ IParameterVisualiser InstantiateVisualiser(Type type);
}
}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs
index 5cfe5117..13c55e57 100644
--- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs
@@ -40,11 +40,10 @@ public class ParameterVisualiserManager : IParameterVisualiserManager
///
protected readonly Dictionary AttributeMapping;
-
///
/// The default constructor.
///
- /// If true it will automatically add all classes marked with the attribute .
+ /// If true , it will automatically add all classes marked with the attribute or .
public ParameterVisualiserManager(bool autoAssign = true)
{
TypeMapping = new Dictionary();
@@ -53,25 +52,29 @@ public ParameterVisualiserManager(bool autoAssign = true)
if (autoAssign)
{
// ReSharper disable once VirtualMemberCallInConstructor
- AssignMarkedClasses();
+ AssignMarkedClasses(typeof(ParameterVisualiserAttribute), typeof(GenericParameterVisualiserAttribute));
}
}
///
- /// Assign all classes that are marked with .
+ /// Assign all classes that are marked with the given attributes (marker attributes).
+ /// These attributes have to be an .
///
- protected virtual void AssignMarkedClasses()
+ protected virtual void AssignMarkedClasses(params Type[] markerTypes)
{
- // get all classes that have the custom attribute
- IEnumerable classes = AttributeUtils.GetTypesWithAttribute(typeof(ParameterVisualiserAttribute));
-
- foreach (Type @class in classes)
+ foreach (Type type in markerTypes)
{
- ParameterVisualiserAttribute[] visualisers = (ParameterVisualiserAttribute[])Attribute.GetCustomAttributes(@class, typeof(ParameterVisualiserAttribute));
+ // get all classes that have the custom attribute
+ IEnumerable classes = AttributeUtils.GetTypesWithAttribute(type);
- foreach (ParameterVisualiserAttribute visualiser in visualisers)
+ foreach (Type @class in classes)
{
- Add(@class, visualiser);
+ IParameterVisualiserInfo[] visualisers = (IParameterVisualiserInfo[]) Attribute.GetCustomAttributes(@class, type);
+
+ foreach (IParameterVisualiserInfo visualiser in visualisers)
+ {
+ Add(@class, visualiser);
+ }
}
}
}
@@ -108,15 +111,19 @@ public virtual bool Add(Type visualiserClass, IParameterVisualiserInfo parameter
// if the mapping has already been added
if (TypeMapping.TryGetValue(parameterInfo.Type, out storedClass) && AttributeMapping.TryGetValue(parameterInfo.Type, out storedAttribte))
{
- // if the new values have a lower priority, we return false
- if (parameterInfo.Priority <= storedAttribte.Priority)
+ // if a different type is being represented (necessary for generics)
+ if (!ReferenceEquals(visualiserClass, storedClass))
{
- _log.Warn($"{parameterInfo.Type} is currently visualised by {storedClass.Name}; {visualiserClass.Name} tried to be the visualiser but has a lower priority ({parameterInfo.Priority} <= {storedAttribte.Priority}).");
+ // if the new values have a lower priority, we return false
+ if (parameterInfo.Priority <= storedAttribte.Priority)
+ {
+ _log.Warn($"{parameterInfo.Type} is currently visualised by {storedClass.Name}; {visualiserClass.Name} tried to be the visualiser but has a lower priority ({parameterInfo.Priority} <= {storedAttribte.Priority}).");
- return false;
- }
+ return false;
+ }
- _log.Info($"{parameterInfo.Type} was visualised by {storedClass.Name}; {visualiserClass.Name} has a higher priority and is therefore the new visualiser ({parameterInfo.Priority} > {storedAttribte.Priority}).");
+ _log.Debug($"{parameterInfo.Type} was visualised by {storedClass.Name}; {visualiserClass.Name} has a higher priority and is therefore the new visualiser ({parameterInfo.Priority} > {storedAttribte.Priority}).");
+ }
}
TypeMapping[parameterInfo.Type] = visualiserClass;
@@ -190,5 +197,20 @@ public Type VisualiserTypeByReference(object obj)
return VisualiserType(obj.GetType());
}
+
+
+ ///
+ public IParameterVisualiser InstantiateVisualiser(Type type)
+ {
+ Type visualiserType = VisualiserType(type);
+ IParameterVisualiserInfo info = AttributeMapping[type];
+
+ if (info.IsGeneric)
+ {
+ return (IParameterVisualiser) Activator.CreateInstance(visualiserType.MakeGenericType(type));
+ }
+
+ return (IParameterVisualiser) Activator.CreateInstance(visualiserType);
+ }
}
}
\ No newline at end of file
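A minimal sketch of how the new `InstantiateVisualiser` resolves generic versus non-generic visualisers, assuming the defaults shown in this diff are auto-registered by the parameterless constructor:

```csharp
IParameterVisualiserManager manager = new ParameterVisualiserManager();

// double is registered via GenericParameterVisualiserAttribute (SigmaGenericBox), so IsGeneric is true
// and the manager closes the open generic visualiser type over double via MakeGenericType before instantiating it.
IParameterVisualiser doubleBox = manager.InstantiateVisualiser(typeof(double));

// object is registered via the plain ParameterVisualiserAttribute (SigmaTextBlock), so the mapped
// visualiser type is instantiated directly.
IParameterVisualiser objectBox = manager.InstantiateVisualiser(typeof(object));
```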
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs
new file mode 100644
index 00000000..d347c7a1
--- /dev/null
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs
@@ -0,0 +1,38 @@
+using System;
+using Sigma.Core.Training.Hooks;
+using Sigma.Core.Utils;
+
+namespace Sigma.Core.Monitors.WPF.ViewModel.Parameterisation
+{
+ ///
+ /// A hook that updates a given IParameterVisualiser on a given TimeStep.
+ ///
+ [Serializable]
+ public class PollParameterHook : BaseHook
+ {
+ ///
+ /// The identifier for the currently active visualiser.
+ ///
+ protected const string VisualiserIdentifier = "visualiser";
+
+ ///
+ /// Create a hook with a certain time step
+ ///
+ /// The time step.
+ /// The visualisers that will be updated.
+ public PollParameterHook(ITimeStep timestep, IParameterVisualiser visualiser) : base(timestep)
+ {
+ ParameterRegistry[VisualiserIdentifier] = visualiser;
+ }
+
+ ///
+ /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
+ ///
+ /// The registry containing the required values for this hook's execution.
+ /// A helper resolver for complex registry entries (automatically cached).
+ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
+ {
+ ((IParameterVisualiser) ParameterRegistry[VisualiserIdentifier]).Read();
+ }
+ }
+}
\ No newline at end of file
diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs b/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs
index 9fe107bc..b737ce86 100644
--- a/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs
+++ b/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs
@@ -17,6 +17,7 @@ For full license see LICENSE in the root directory of this project.
using Sigma.Core.Monitors.WPF.Model.UI.Windows;
using Sigma.Core.Monitors.WPF.Panels;
using Sigma.Core.Monitors.WPF.View;
+using Sigma.Core.Monitors.WPF.View.Windows;
using Sigma.Core.Monitors.WPF.ViewModel.CustomControls;
using WPFGrid = System.Windows.Controls.Grid;
@@ -39,17 +40,24 @@ public class TabUI : UIWrapper
///
private GridSize _gridSize;
+ ///
+ /// The monitor this tab is assigned to.
+ ///
+ protected WPFMonitor Monitor;
+
///
/// Create a new - this basically is a
/// with additional control.
///
+ /// The monitor this tab is assigned to.
/// The header of the tab (name in the )
///
- /// The . Use
- /// .
+ /// The . Use
+ /// .
///
- public TabUI(string header, GridSize gridsize)
+ public TabUI(WPFMonitor monitor, string header, GridSize gridsize)
{
+ Monitor = monitor;
Content.Header = header;
GridSize = gridsize;
}
@@ -205,8 +213,10 @@ protected virtual void ApplyLegend(SigmaPanel panel, StatusBarLegendInfo legend)
public void AddCumulativePanel(SigmaPanel panel, int rowSpan = 1, int columnSpan = 1,
StatusBarLegendInfo legend = null)
{
+ panel.Monitor = Monitor;
AddCumulativeElement(panel, rowSpan, columnSpan);
ApplyLegend(panel, legend);
+ panel.Initialise(Monitor.Window);
}
///
@@ -266,8 +276,10 @@ public void AddCumulativeElement(UIElement element, int rowSpan = 1, int columnS
public void AddPanel(SigmaPanel panel, int row, int column, int rowSpan = 1, int columnSpan = 1,
StatusBarLegendInfo legend = null)
{
+ panel.Monitor = Monitor;
AddElement(panel, row, column, rowSpan, columnSpan);
ApplyLegend(panel, legend);
+ panel.Initialise(Monitor.Window);
}
///
diff --git a/Sigma.Core/Architecture/INetwork.cs b/Sigma.Core/Architecture/INetwork.cs
index 5666ab18..b99ace47 100644
--- a/Sigma.Core/Architecture/INetwork.cs
+++ b/Sigma.Core/Architecture/INetwork.cs
@@ -34,6 +34,16 @@ public interface INetwork : IDeepCopyable
///
IRegistry Registry { get; }
+ ///
+ /// The computation handler associated with this network, which is used for initialisation and copy operations.
+ ///
+ IComputationHandler AssociatedHandler { get; set; }
+
+ ///
+ /// Indicate if this network was already initialised.
+ ///
+ bool Initialised { get; }
+
///
/// Validate this network (e.g. ensure all connections are correctly assigned and compatible).
///
diff --git a/Sigma.Core/Architecture/Network.cs b/Sigma.Core/Architecture/Network.cs
index 1e2b233a..0c04e528 100644
--- a/Sigma.Core/Architecture/Network.cs
+++ b/Sigma.Core/Architecture/Network.cs
@@ -19,266 +19,284 @@ For full license see LICENSE in the root directory of this project.
namespace Sigma.Core.Architecture
{
- ///
- /// A default implementation of the interface.
- /// Represents a neural network consisting of interconnected neural layers and a network architecture.
- ///
- [Serializable]
- public class Network : INetwork
- {
- ///
- public INetworkArchitecture Architecture { get; set; }
-
- ///
- public string Name { get; }
-
- ///
- public IRegistry Registry { get; }
-
- [NonSerialized]
- private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
- private readonly List _orderedLayerBuffers;
- private readonly List _externalInputsLayerBuffers;
- private readonly List _externalOutputsLayerBuffers;
- private List _orderedLayers;
- private IComputationHandler _initialisationHandler;
- private bool _initialised;
-
- ///
- /// Create a network with a certain unique name.
- ///
- /// The name.
- public Network(string name = "unnamed")
- {
- if (name == null) throw new ArgumentNullException(nameof(name));
-
- Name = name;
- Registry = new Registry(tags: "network");
- _orderedLayerBuffers = new List();
- _externalInputsLayerBuffers = new List();
- _externalOutputsLayerBuffers = new List();
- }
-
- ///
- public virtual object DeepCopy()
- {
- Network copy = new Network(Name);
- copy.Architecture = (INetworkArchitecture) Architecture.DeepCopy();
-
- if (_initialised)
- {
- copy.Initialise(_initialisationHandler);
-
- for (int i = 0; i < _orderedLayerBuffers.Count; i++)
- {
- InternalLayerBuffer originalBuffer = _orderedLayerBuffers[i];
- InternalLayerBuffer copyBuffer = copy._orderedLayerBuffers[i];
-
- foreach (string parameterIdentifier in originalBuffer.Parameters.Keys.ToArray())
- {
- object value = originalBuffer.Parameters[parameterIdentifier];
- IDeepCopyable deepCopyableValue = value as IDeepCopyable;
- object copiedValue;
-
- // copy and copy efficiently by any means possible
- if (deepCopyableValue == null)
- {
- ICloneable cloneableValue = value as ICloneable;
- copiedValue = cloneableValue?.Clone() ?? value;
- }
- else
- {
- INDArray asNDArray = value as INDArray;
-
- if (asNDArray != null)
- {
- _initialisationHandler.Fill(asNDArray, copyBuffer.Parameters.Get(parameterIdentifier));
- }
-
- copiedValue = deepCopyableValue.DeepCopy();
- }
-
- copyBuffer.Parameters[parameterIdentifier] = copiedValue;
- }
- }
- }
-
- return copy;
- }
-
- ///
- public void Validate()
- {
- if (Architecture == null)
- {
- throw new InvalidOperationException("Cannot validate network before assigning a network architecture.");
- }
-
- Architecture.Validate();
- }
-
- ///
- public void Initialise(IComputationHandler handler)
- {
- if (handler == null) throw new ArgumentNullException(nameof(handler));
-
- if (Architecture == null)
- {
- throw new InvalidOperationException("Cannot initialise network before assigning a network architecture.");
- }
+ ///
+ /// A default implementation of the interface.
+ /// Represents a neural network consisting of interconnected neural layers and a network architecture.
+ ///
+ [Serializable]
+ public class Network : INetwork
+ {
+ ///
+ public INetworkArchitecture Architecture { get; set; }
+
+ ///
+ public string Name { get; }
+
+ ///
+ public IRegistry Registry { get; }
+
+ ///
+ /// The computation handler associated with this network, which is used for initialisation and copy operations.
+ /// Note: Set this
+ ///
+ public IComputationHandler AssociatedHandler
+ {
+ get { return _associatedHandler; }
+ set { _associatedHandler = value; }
+ }
+
+ ///
+ public bool Initialised { get { return _initialised; } }
+
+ [NonSerialized]
+ private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
+ private readonly List _orderedLayerBuffers;
+ private readonly List _externalInputsLayerBuffers;
+ private readonly List _externalOutputsLayerBuffers;
+ private List _orderedLayers;
+
+ [NonSerialized]
+ private IComputationHandler _associatedHandler;
+ private bool _initialised;
+
+ ///
+ /// Create a network with a certain unique name.
+ ///
+ /// The name.
+ public Network(string name = "unnamed")
+ {
+ if (name == null) throw new ArgumentNullException(nameof(name));
+
+ Name = name;
+ Registry = new Registry(tags: "network");
+ _orderedLayerBuffers = new List();
+ _externalInputsLayerBuffers = new List();
+ _externalOutputsLayerBuffers = new List();
+ }
+
+ ///
+ public virtual object DeepCopy()
+ {
+ Network copy = new Network(Name);
+ copy.Architecture = (INetworkArchitecture) Architecture.DeepCopy();
+
+ if (_initialised)
+ {
+ copy.Initialise(_associatedHandler);
+
+ for (int i = 0; i < _orderedLayerBuffers.Count; i++)
+ {
+ InternalLayerBuffer originalBuffer = _orderedLayerBuffers[i];
+ InternalLayerBuffer copyBuffer = copy._orderedLayerBuffers[i];
+
+ foreach (string parameterIdentifier in originalBuffer.Parameters.Keys.ToArray())
+ {
+ object value = originalBuffer.Parameters[parameterIdentifier];
+ IDeepCopyable deepCopyableValue = value as IDeepCopyable;
+ object copiedValue;
+
+ // copy and copy efficiently by any means possible
+ if (deepCopyableValue == null)
+ {
+ ICloneable cloneableValue = value as ICloneable;
+ copiedValue = cloneableValue?.Clone() ?? value;
+ }
+ else
+ {
+ INDArray asNDArray = value as INDArray;
+
+ if (asNDArray != null)
+ {
+ _associatedHandler.Fill(asNDArray, copyBuffer.Parameters.Get(parameterIdentifier));
+ copiedValue = copyBuffer.Parameters.Get(parameterIdentifier);
+ }
+ else
+ {
+ copiedValue = deepCopyableValue.DeepCopy();
+ }
+ }
+
+ copyBuffer.Parameters[parameterIdentifier] = copiedValue;
+ }
+ }
+ }
+
+ return copy;
+ }
+
+ ///
+ public void Validate()
+ {
+ if (Architecture == null)
+ {
+ throw new InvalidOperationException("Cannot validate network before assigning a network architecture.");
+ }
+
+ Architecture.Validate();
+ }
+
+ ///
+ public void Initialise(IComputationHandler handler)
+ {
+ if (handler == null) throw new ArgumentNullException(nameof(handler));
+
+ if (Architecture == null)
+ {
+ throw new InvalidOperationException("Cannot initialise network before assigning a network architecture.");
+ }
+
+ _logger.Debug($"Initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers...");
+
+ _associatedHandler = handler;
+
+ ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare);
+
+ Architecture.ResolveAllNames();
+
+ _orderedLayerBuffers.Clear();
+ _externalInputsLayerBuffers.Clear();
+ _externalOutputsLayerBuffers.Clear();
+
+ Dictionary<Tuple<LayerConstruct, LayerConstruct>, IRegistry> mappedRegistriesByInOutputs = new Dictionary<Tuple<LayerConstruct, LayerConstruct>, IRegistry>();
+
+ foreach (LayerConstruct layerConstruct in Architecture.YieldLayerConstructsOrdered())
+ {
+ ILayer layer = layerConstruct.InstantiateLayer(handler);
+
+ Dictionary<string, IRegistry> inputs = new Dictionary<string, IRegistry>();
+
+ foreach (string externalInputAlias in layerConstruct.ExternalInputs)
+ {
+ inputs[externalInputAlias] = new Registry(tags: "external_input");
+ }
- _logger.Debug($"Initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers...");
+ foreach (string inputAlias in layerConstruct.Inputs.Keys)
+ {
+ inputs[inputAlias] = mappedRegistriesByInOutputs[new Tuple<LayerConstruct, LayerConstruct>(layerConstruct.Inputs[inputAlias], layerConstruct)];
+ }
- _initialisationHandler = handler;
-
- ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare);
-
- Architecture.ResolveAllNames();
-
- _orderedLayerBuffers.Clear();
- _externalInputsLayerBuffers.Clear();
- _externalOutputsLayerBuffers.Clear();
-
- Dictionary, IRegistry> mappedRegistriesByInOutputs = new Dictionary, IRegistry>();
-
- foreach (LayerConstruct layerConstruct in Architecture.YieldLayerConstructsOrdered())
- {
- ILayer layer = layerConstruct.InstantiateLayer(handler);
-
- Dictionary inputs = new Dictionary();
+ Dictionary<string, IRegistry> outputs = new Dictionary<string, IRegistry>();
- foreach (string externalInputAlias in layerConstruct.ExternalInputs)
- {
- inputs[externalInputAlias] = new Registry(tags: "external_input");
- }
+ foreach (string externalOutputAlias in layerConstruct.ExternalOutputs)
+ {
+ outputs[externalOutputAlias] = new Registry(tags: "external_output");
+ }
- foreach (string inputAlias in layerConstruct.Inputs.Keys)
- {
- inputs[inputAlias] = mappedRegistriesByInOutputs[new Tuple(layerConstruct.Inputs[inputAlias], layerConstruct)];
- }
+ foreach (string outputAlias in layerConstruct.Outputs.Keys)
+ {
+ LayerConstruct outputConstruct = layerConstruct.Outputs[outputAlias];
+
+ Tuple<LayerConstruct, LayerConstruct> inOuTuple = new Tuple<LayerConstruct, LayerConstruct>(layerConstruct, outputConstruct);
- Dictionary outputs = new Dictionary();
+ Registry outRegistry = new Registry(tags: "internal");
+
+ mappedRegistriesByInOutputs.Add(inOuTuple, outRegistry);
+
+ outputs[outputAlias] = outRegistry;
+ }
- foreach (string externalOutputAlias in layerConstruct.ExternalOutputs)
- {
- outputs[externalOutputAlias] = new Registry(tags: "external_output");
- }
+ InternalLayerBuffer layerBuffer = new InternalLayerBuffer(layer, layerConstruct.Parameters, inputs, outputs,
+ layerConstruct.ExternalInputs, layerConstruct.ExternalOutputs);
+
+ _orderedLayerBuffers.Add(layerBuffer);
+
+ if (layerConstruct.ExternalInputs.Length > 0)
+ {
+ _externalInputsLayerBuffers.Add(layerBuffer);
+ }
- foreach (string outputAlias in layerConstruct.Outputs.Keys)
- {
- LayerConstruct outputConstruct = layerConstruct.Outputs[outputAlias];
+ if (layerConstruct.ExternalOutputs.Length > 0)
+ {
+ _externalOutputsLayerBuffers.Add(layerBuffer);
+ }
+ }
- Tuple inOuTuple = new Tuple(layerConstruct, outputConstruct);
+ _orderedLayers = _orderedLayerBuffers.ConvertAll(buffer => buffer.Layer);
- Registry outRegistry = new Registry(tags: "internal");
+ UpdateRegistry();
- mappedRegistriesByInOutputs.Add(inOuTuple, outRegistry);
+ SigmaEnvironment.TaskManager.EndTask(prepareTask);
- outputs[outputAlias] = outRegistry;
- }
+ _initialised = true;
- InternalLayerBuffer layerBuffer = new InternalLayerBuffer(layer, layerConstruct.Parameters, inputs, outputs,
- layerConstruct.ExternalInputs, layerConstruct.ExternalOutputs);
+ _logger.Debug($"Done initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers.");
+ }
- _orderedLayerBuffers.Add(layerBuffer);
+ protected virtual void UpdateRegistry()
+ {
+ Registry.Clear();
- if (layerConstruct.ExternalInputs.Length > 0)
- {
- _externalInputsLayerBuffers.Add(layerBuffer);
- }
+ Registry["initialised"] = _initialised;
+ Registry["self"] = this;
+ Registry["name"] = Name;
+ Registry["architecture"] = Architecture?.Registry;
- if (layerConstruct.ExternalOutputs.Length > 0)
- {
- _externalOutputsLayerBuffers.Add(layerBuffer);
- }
- }
+ Registry layersRegistry = new Registry(Registry);
+ Registry["layers"] = layersRegistry;
- _orderedLayers = _orderedLayerBuffers.ConvertAll(buffer => buffer.Layer);
+ foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers)
+ {
+ layersRegistry[layerBuffer.Layer.Name] = layerBuffer.Layer.Parameters;
+ }
+ }
- UpdateRegistry();
+ ///
+ public void Run(IComputationHandler handler, bool trainingPass)
+ {
+ if (handler == null) throw new ArgumentNullException(nameof(handler));
- SigmaEnvironment.TaskManager.EndTask(prepareTask);
+ foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers)
+ {
+ layerBuffer.Layer.Run(layerBuffer, handler, trainingPass);
+ }
+ }
- _initialised = true;
+ ///
+ public void Reset()
+ {
+ _logger.Debug($"Resetting network \"{Name}\" to un-initialised state...");
- _logger.Debug($"Done initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers.");
- }
+ _orderedLayerBuffers.Clear();
+ _orderedLayers.Clear();
+ _externalInputsLayerBuffers.Clear();
+ _externalOutputsLayerBuffers.Clear();
- protected virtual void UpdateRegistry()
- {
- Registry.Clear();
+ _initialised = false;
+ _associatedHandler = null;
- Registry["initialised"] = _initialised;
- Registry["self"] = this;
- Registry["name"] = Name;
- Registry["architecture"] = Architecture?.Registry;
+ UpdateRegistry();
- Registry layersRegistry = new Registry(Registry);
- Registry["layers"] = layersRegistry;
+ _logger.Debug($"Done resetting network \"{Name}\". All layer buffer information was discarded.");
+ }
- foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers)
- {
- layersRegistry[layerBuffer.Layer.Name] = layerBuffer.Layer.Parameters;
- }
- }
+ ///
+ public IEnumerable<ILayer> YieldLayersOrdered()
+ {
+ return _orderedLayers;
+ }
- ///
- public void Run(IComputationHandler handler, bool trainingPass)
- {
- if (handler == null) throw new ArgumentNullException(nameof(handler));
+ ///
+ public IEnumerable<ILayerBuffer> YieldLayerBuffersOrdered()
+ {
+ return _orderedLayerBuffers;
+ }
- foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers)
- {
- layerBuffer.Layer.Run(layerBuffer, handler, trainingPass);
- }
- }
+ ///
+ public IEnumerable<ILayerBuffer> YieldExternalInputsLayerBuffers()
+ {
+ return _externalInputsLayerBuffers;
+ }
- ///
- public void Reset()
- {
- _logger.Debug($"Resetting network \"{Name}\" to un-initialised state...");
-
- _orderedLayerBuffers.Clear();
- _orderedLayers.Clear();
- _externalInputsLayerBuffers.Clear();
- _externalOutputsLayerBuffers.Clear();
-
- _initialised = false;
- _initialisationHandler = null;
-
- UpdateRegistry();
-
- _logger.Debug($"Done resetting network \"{Name}\". All layer buffer information was discarded.");
- }
-
- ///
- public IEnumerable YieldLayersOrdered()
- {
- return _orderedLayers;
- }
-
- ///
- public IEnumerable YieldLayerBuffersOrdered()
- {
- return _orderedLayerBuffers;
- }
-
- ///
- public IEnumerable YieldExternalInputsLayerBuffers()
- {
- return _externalInputsLayerBuffers;
- }
-
- ///
- public IEnumerable YieldExternalOutputsLayerBuffers()
- {
- return _externalOutputsLayerBuffers;
- }
-
- ///
- public INetworkSelector Select()
- {
- return new DefaultNetworkSelector(this);
- }
- }
+ ///
+ public IEnumerable<ILayerBuffer> YieldExternalOutputsLayerBuffers()
+ {
+ return _externalOutputsLayerBuffers;
+ }
+
+ ///
+ public INetworkSelector Select()
+ {
+ return new DefaultNetworkSelector(this);
+ }
+ }
}
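
As a usage sketch (not part of the patch): the reworked Network above now keeps a single AssociatedHandler and re-initialises copies through it, so a typical lifecycle looks roughly like the snippet below. The concrete handler type (CpuFloat32Handler) and the architecture variable are assumptions for illustration; everything else comes from the diff above.

// Illustrative sketch: 'architecture' is any INetworkArchitecture built elsewhere,
// and CpuFloat32Handler stands in for whichever IComputationHandler the caller uses.
IComputationHandler handler = new CpuFloat32Handler();

Network network = new Network("example-net") { Architecture = architecture };
network.Validate();           // throws InvalidOperationException if no architecture is assigned
network.Initialise(handler);  // builds the ordered layer buffers and sets AssociatedHandler

// DeepCopy re-initialises the copy and fills INDArray parameters via the associated handler
// instead of the removed _initialisationHandler field.
Network copy = (Network) network.DeepCopy();
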
diff --git a/Sigma.Core/Data/Datasets/Dataset.cs b/Sigma.Core/Data/Datasets/Dataset.cs
deleted file mode 100644
index d84396f4..00000000
--- a/Sigma.Core/Data/Datasets/Dataset.cs
+++ /dev/null
@@ -1,1048 +0,0 @@
-/*
-MIT License
-
-Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer
-
-For full license see LICENSE in the root directory of this project.
-*/
-
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
-using log4net;
-using Sigma.Core.Data.Extractors;
-using Sigma.Core.Handlers;
-using Sigma.Core.MathAbstract;
-using Sigma.Core.Persistence;
-using Sigma.Core.Utils;
-
-namespace Sigma.Core.Data.Datasets
-{
- ///
- /// A default implementation of the IDataset interface.
- /// Provides caching of entire blocks and reader data, partial extraction, unordered extraction, automatic block sizing, smart block loading.
- ///
- [Serializable]
- public class Dataset : IDataset, ISerialisationNotifier
- {
- [NonSerialized]
- private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
-
- ///
- /// Automatically size blocks according to estimated data metrics (e.g. physical memory available, record size).
- ///
- public const int BlockSizeAuto = -1;
-
- ///
- /// Assign all available data to the first block (one block fits it all - literally).
- ///
- public const int BlockSizeAll = -2;
-
- ///
- public string Name { get; }
-
- ///
- /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime).
- /// By default, this is assumed to be false, indicating a static dataset.
- /// Note: Data iterators and may perform certain optimisations for static datasets, so set this to false if possible.
- ///
- public bool Online { get; set; } = false;
-
- ///
- public int MaxConcurrentActiveBlocks { get; } = 24; //24 seems like a good number, right?
-
- ///
- public long MaxTotalActiveBlockSizeBytes { get; } = SystemInformationUtils.GetAvailablePhysicalMemoryBytes() / 2; //default to half the available physical memory
-
- ///
- public IReadOnlyCollection ActiveBlockIndices => _activeBlocks.Keys.ToList();
-
- ///
- public int ActiveBlockRegionCount => _activeBlocks.Count;
-
- ///
- public int ActiveIndividualBlockCount { get { return _activeBlocks.Values.Sum(set => set.Count); } }
-
- ///
- public int TargetBlockSizeRecords { get; private set; }
-
- ///
- public string[] SectionNames { get; private set; }
-
- ///
- public long TotalActiveBlockSizeBytes { get; private set; }
-
- public long TotalActiveRecords { get; private set; }
-
- ///
- public int MaxBlocksInCache { get; set; } = int.MaxValue;
-
- ///
- public long MaxBytesInCache { get; set; } = long.MaxValue;
-
- ///
- /// Indicate whether this dataset should cache the raw reader data.
- /// If disabled, only extracted data will be cached and once processed, it might be impossible to retrieve preceding record blocks (reader streams are assumed to be non-seekable).
- ///
- public bool AllowRawReadDataCaching { get; set; } = true;
-
- private readonly Dictionary> _activeBlocks;
- private readonly Dictionary> _cachedBlocks;
- private readonly ICacheProvider _cacheProvider;
-
- private int _lastReadRawDataBlockIndex = -1;
- private long _totalCachedBlockSizeBytes;
- private int _lastAvailableBlockIndex = int.MaxValue;
- private readonly ISet _recordExtractors;
-
- private readonly bool _autoSetBlockSize;
- private bool _autoSetExternalChangeBlockSize;
-
- // TODO fix available blocks semaphore logic
- // the waitones/releases are inconsistent, because blocks aren't always actually allocated, such as null returns are not considered
- [NonSerialized]
- private Semaphore _availableBlocksSemaphore;
- private int _availableBlocksSemaphoreState;
-
- ///
- /// Create a dataset with a certain unique name and the record extractors to use.
- ///
- /// The unique dataset name.
- /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
- public Dataset(string name, params IRecordExtractor[] recordExtractors) : this(name, BlockSizeAuto, recordExtractors)
- {
- }
-
- ///
- /// Create a dataset with a certain unique name, target block size in records and the record extractors to use.
- ///
- /// The unique dataset name.
- /// The target block size for records. May also be or .
- /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
- public Dataset(string name, int blockSizeRecords, params IRecordExtractor[] recordExtractors)
- : this(name, blockSizeRecords, new DiskCacheProvider(SigmaEnvironment.Globals.Get("cache_path") + name), true, recordExtractors)
- {
- }
-
- ///
- /// Create a dataset with a certain unique name, target block size in records, specific cache provider and the record extractors to use.
- ///
- /// The unique dataset name.
- /// The target block size for records. May also be or .
- /// The cache provider to use for caching record blocks and raw reader data.
- /// Indicate whether the cache provider should be flushed (cleared) before use. Only disable if block size and extractors used do not change (otherwise undefined behaviour).
- /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
- public Dataset(string name, int blockSizeRecords, ICacheProvider cacheProvider, bool flushCache = true, params IRecordExtractor[] recordExtractors)
- {
- if (name == null)
- {
- throw new ArgumentNullException(nameof(name));
- }
-
- if (recordExtractors == null)
- {
- throw new ArgumentNullException(nameof(recordExtractors));
- }
-
- if (recordExtractors.Length == 0)
- {
- throw new ArgumentException("Datasets require at least one record extractor, but none were given.");
- }
-
- if (cacheProvider == null)
- {
- throw new ArgumentNullException(nameof(cacheProvider));
- }
-
- switch (blockSizeRecords) {
- case BlockSizeAll:
- //just set to maximum amount of records, extracting returns the maximum available anyway and we can't know the actual availability yet
- TargetBlockSizeRecords = int.MaxValue;
- break;
- case BlockSizeAuto:
- //somewhat temporary guesstimate, should probably expose the individual parameters
- const long estimatedRecordSizeBytes = 1024;
- const double memoryToConsume = 0.2f;
- const long optimalNumberBlocks = 8;
- const int maxBlockSizeRecords = 4096;
- long availableSystemMemory = SystemInformationUtils.GetAvailablePhysicalMemoryBytes();
-
- TargetBlockSizeRecords = Math.Min(maxBlockSizeRecords, (int) (availableSystemMemory * memoryToConsume / estimatedRecordSizeBytes / optimalNumberBlocks));
-
- _autoSetBlockSize = true;
- break;
- default:
- if (blockSizeRecords == 0 || blockSizeRecords < -2)
- {
- throw new ArgumentException($"Block size in records must be either BLOCK_SIZE_ALL, BLOCK_SIZE_AUTO or > 0, but given block size was {blockSizeRecords}.");
- }
- else
- {
- TargetBlockSizeRecords = blockSizeRecords;
- }
- break;
- }
-
- Name = name;
- AnalyseExtractors(recordExtractors);
-
- _cacheProvider = cacheProvider;
- _recordExtractors = new HashSet(recordExtractors);
-
- _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks, MaxConcurrentActiveBlocks);
- _availableBlocksSemaphoreState = MaxConcurrentActiveBlocks;
-
- _activeBlocks = new Dictionary>();
- _cachedBlocks = new Dictionary>();
-
- if (flushCache)
- {
- _logger.Debug($"Flushing all caches for dataset \"{Name}\" as flushCache flag was set...");
-
- InvalidateAndClearCaches();
-
- _logger.Debug($"Done flushing all caches for dataset \"{Name}.\"");
- }
- }
-
- ///
- /// Called before this object is serialised.
- ///
- public void OnSerialising()
- {
- }
-
- ///
- /// Called after this object was serialised.
- ///
- public void OnSerialised()
- {
- }
-
- ///
- /// Called after this object was de-serialised.
- ///
- public void OnDeserialised()
- {
- InvalidateAndClearCaches();
- _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks - ActiveIndividualBlockCount, MaxConcurrentActiveBlocks);
- }
-
- public IDataset[] SplitBlockwise(params int[] parts)
- {
- return SplitBlockwise(this, parts);
- }
-
- public IDataset[] SplitRecordwise(params double[] parts)
- {
- return SplitRecordwise(this, parts);
- }
-
- public bool TrySetBlockSize(int blockSizeRecords)
- {
- if (blockSizeRecords == TargetBlockSizeRecords)
- {
- //nothing to do here
- return true;
- }
-
- if (!_autoSetBlockSize)
- {
- _logger.Debug($"Cannot change block size as block size was not set automatically (attempted to change block size to {blockSizeRecords}.");
-
- return false;
- }
-
- if (_activeBlocks.Count > 0 || _cachedBlocks.Count > 0)
- {
- _logger.Debug($"Cannot change block size as {_activeBlocks.Count + _cachedBlocks.Count} blocks were already fetched and are active or cached.");
-
- return false;
- }
-
- if (_autoSetExternalChangeBlockSize && blockSizeRecords != TargetBlockSizeRecords)
- {
- _logger.Debug($"Cannot change block size to {blockSizeRecords}, block size is incompatible with another external block size change request (other request: {TargetBlockSizeRecords})");
-
- return false;
- }
-
- _autoSetExternalChangeBlockSize = true;
- TargetBlockSizeRecords = blockSizeRecords;
-
- return true;
- }
-
- private void AnalyseExtractors(IEnumerable extractors)
- {
- ISet sectionNames = new HashSet();
-
- int index = 0;
- foreach (IRecordExtractor extractor in extractors)
- {
- if (extractor == null)
- {
- throw new ArgumentNullException($"Extractor at index {index} was null.");
- }
-
- if (extractor.SectionNames == null)
- {
- throw new ArgumentNullException($"Section names field in extractor {extractor} was null (field has to be set by extractor).");
- }
-
- string[] extractorSectionNames = extractor.SectionNames;
-
- foreach (string sectionName in extractorSectionNames)
- {
- if (sectionNames.Contains(sectionName))
- {
- throw new ArgumentException($"Section name collision: duplicate section name {sectionName} detected for extractor {extractor}.");
- }
- else
- {
- sectionNames.Add(sectionName);
- }
- }
-
- index++;
- }
-
- SectionNames = sectionNames.ToArray();
- }
-
- public int GetNumberOfLoadedInactiveCachedBlocks()
- {
- return _cachedBlocks.Values.SelectMany(blockSet => blockSet).Count(block => block.Loaded);
- }
-
- public bool CanFetchBlocksAfter(int blockIndex)
- {
- return blockIndex <= _lastAvailableBlockIndex;
- }
-
- public async Task> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
- {
- //TODO check if block even could be fetched to not waste thread resources if shouldWaitUntilAvailable is false anyway
-
- return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable));
- }
-
- public IDictionary FetchBlock(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
- {
- Dictionary block = FetchBlockConstrained(blockIndex, handler);
-
- //block could be fetched directly without violating any constraints, return successfully
- if (block != null)
- {
- if (block.Count == 0)
- {
- throw new InvalidOperationException("Fetched block did not contain any named elements (was empty; is the extractor output correct?).");
- }
-
- RegisterActiveBlock(block, blockIndex, handler);
-
- return block;
- }
- else
- {
- if (blockIndex >= _lastAvailableBlockIndex)
- {
- return null;
- }
-
- if (shouldWaitUntilAvailable)
- {
- _logger.Debug($"Could not directly load block with index {blockIndex} for handler {handler} and shouldWaitUntilAvailable flag is set to true, waiting for available space...");
-
- return FetchBlockWhenAvailable(blockIndex, handler);
- }
- else
- {
- return null;
- }
- }
- }
-
- private void RegisterActiveBlock(Dictionary block, int blockIndex, IComputationHandler handler)
- {
- INDArray firstNamedBlock = block[block.First().Key];
-
- if (IsBlockActive(blockIndex, handler))
- {
- //block already registered as active, nothing to do here
- return;
- }
-
- RecordBlock recordBlock = new RecordBlock(block, blockIndex, firstNamedBlock.Shape[0],
- handler.GetSizeBytes(block.Values.ToArray()), handler)
- { Loaded = true, Active = true };
-
- lock (this)
- {
- TotalActiveBlockSizeBytes += recordBlock.EstimatedSizeBytes;
- TotalActiveRecords += recordBlock.NumberRecords;
-
- if (!_activeBlocks.ContainsKey(blockIndex))
- {
- _activeBlocks.Add(blockIndex, new HashSet());
- }
-
- _activeBlocks[blockIndex].Add(recordBlock);
- }
- }
-
- private void DeregisterActiveBlock(RecordBlock recordBlock)
- {
- if (!IsBlockActive(recordBlock.BlockIndex, recordBlock.Handler))
- {
- //block that should be de-registered is not even registered
- return;
- }
-
- lock (this)
- {
- TotalActiveBlockSizeBytes -= recordBlock.EstimatedSizeBytes;
- TotalActiveRecords -= recordBlock.NumberRecords;
-
- _activeBlocks[recordBlock.BlockIndex].Remove(recordBlock);
-
- if (_activeBlocks[recordBlock.BlockIndex].Count == 0)
- {
- _activeBlocks.Remove(recordBlock.BlockIndex);
- }
- }
- }
-
- private void RegisterCachedBlock(Dictionary block, int blockIndex, IComputationHandler handler, bool keepReference)
- {
- if (IsBlockCached(blockIndex, handler))
- {
- //block's already cached, nothing to do here
- return;
- }
-
- if (!_cachedBlocks.ContainsKey(blockIndex))
- {
- _cachedBlocks.Add(blockIndex, new HashSet());
- }
-
- WeakRecordBlock recordBlock = new WeakRecordBlock(keepReference ? block : null, blockIndex, block.First().Value.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler);
-
- recordBlock.Loaded = false;
-
- _cachedBlocks[blockIndex].Add(recordBlock);
- }
-
- ///
- /// Invalidate and clear all caches associated with this dataset.
- /// WARNING: Removing cache entries may cause certain datasets to load much more slowly or even incorrectly.
- /// Legitimate use cases include removing cache entries for old datasets or changing extractors.
- ///
- public void InvalidateAndClearCaches()
- {
- _logger.Debug("Invalidating and clearing all caches...");
-
- _cacheProvider.RemoveAll();
- _cachedBlocks.Clear();
- _totalCachedBlockSizeBytes = 0L;
-
- _logger.Debug("Done invalidating and clearing all caches.");
- }
-
- private Dictionary FetchBlockWhenAvailable(int blockIndex, IComputationHandler handler)
- {
- while (true)
- {
- _logger.Debug($"Attempting to extract block region for request for block index {blockIndex} for handler {handler}, checking if it fits all constraints...");
-
- Dictionary block = FetchBlockConstrained(blockIndex, handler);
-
- //if block != null we could fetch the block successfully without violating any constraints
- if (block != null)
- {
- RegisterActiveBlock(block, blockIndex, handler);
-
- return block;
- }
- else
- {
- //we cannot retrieve any more blocks and shouldn't keep trying
- if (blockIndex >= _lastAvailableBlockIndex)
- {
- return null;
- }
-
- _logger.Debug($"Request for block with index {blockIndex} for handler {handler} was returned to the queue, seems to be violating constraints...");
- }
- }
- }
-
- private Dictionary FetchBlockConstrained(int blockIndex, IComputationHandler handler)
- {
- if (ActiveIndividualBlockCount >= MaxConcurrentActiveBlocks)
- {
- _logger.Debug($"Unable to fetch block due to MaxConcurrentActiveBlocks constraint of {MaxConcurrentActiveBlocks}.");
-
- return null;
- }
-
- Dictionary block = LoadAndExtractBlockWhenAvailable(blockIndex, handler);
-
- //there was nothing to load and extract, most likely end of stream
- if (block == null)
- {
- return null;
- }
-
- long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray());
-
- if (TotalActiveBlockSizeBytes + blockSizeBytes > MaxTotalActiveBlockSizeBytes)
- {
- _logger.Debug($"Unable to keep requested block {blockIndex} for handler {handler} in memory due to MaxTotalActiveBlockSizeBytes constraint of {MaxTotalActiveBlockSizeBytes} bytes (block of size {blockSizeBytes} would exceed constraint by {TotalActiveBlockSizeBytes + blockSizeBytes - MaxTotalActiveBlockSizeBytes} bytes.).");
-
- CacheBlockConstrained(block, blockIndex, handler);
-
- return null;
- }
-
- return block;
- }
-
- private Dictionary LoadAndExtractBlockWhenAvailable(int blockIndex, IComputationHandler handler)
- {
- //this method takes care of
- // - checking whether the index is already loaded and active and then converts it
- // - or checking whether the index is already cached in the right format and loads
- // - or if none of that, loads and extracts from the original extractors
-
- //check whether a block with the same index and format is already active
- if (_activeBlocks.ContainsKey(blockIndex))
- {
- Dictionary block = GetBestMatchedBlockWhenAvailable(_activeBlocks[blockIndex], handler);
-
- if (block != null)
- {
- return block;
- }
- }
-
- //check whether a block with the same index and format is already loaded and cached but not active
- if (_cachedBlocks.ContainsKey(blockIndex))
- {
- Dictionary block = GetBestMatchedBlockWhenAvailable(_cachedBlocks[blockIndex], handler);
-
- if (block != null)
- {
- return block;
- }
- }
-
- lock (_cacheProvider)
- {
- string blockIdentifierInCache = $"extracted.{blockIndex}.{handler.DataType.Identifier}";
-
- //check whether a block of the same index and format is cached in the cache provider
- if (_cacheProvider.IsCached(blockIdentifierInCache))
- {
- Dictionary block = _cacheProvider.Load>(blockIdentifierInCache);
-
- //if its != null we could read it correctly in the right format
- if (block != null)
- {
- //register this cache entry as a properly loaded block in case the cache wasn't flushed and the cache map is outdated
- RegisterCachedBlock(block, blockIndex, handler, keepReference: false);
-
- return block;
- }
- }
- }
-
- //_availableBlocksSemaphore.WaitOne();
- //_availableBlocksSemaphoreState--;
-
- return LoadAndExtractRaw(blockIndex, handler);
- }
-
- private Dictionary GetBestMatchedBlockWhenAvailable(IEnumerable blocks, IComputationHandler handler)
- {
- RecordBlockBase bestMatchedBlock = null;
-
- foreach (RecordBlockBase otherBlock in blocks)
- {
- if (otherBlock.Loaded && handler.CanConvert(otherBlock.FirstNamedBlock, otherBlock.Handler))
- {
- if (handler.IsInterchangeable(otherBlock.Handler))
- {
- //no need to look any further, we already found the perfect match and can return without conversion
- return otherBlock.NamedBlockSections;
- }
-
- bestMatchedBlock = otherBlock;
- }
- }
-
- if (bestMatchedBlock == null)
- {
- return null;
- }
-
- //_availableBlocksSemaphore.WaitOne();
- //_availableBlocksSemaphoreState--;
-
- return ConvertNamedBlocks(bestMatchedBlock.NamedBlockSections, handler);
- }
-
- private static Dictionary ConvertNamedBlocks(Dictionary namedBlockSections, IComputationHandler handler)
- {
- Dictionary convertedNamedBlocks = new Dictionary();
-
- foreach (string name in namedBlockSections.Keys)
- {
- convertedNamedBlocks.Add(name, handler.Convert(namedBlockSections[name], handler));
- }
-
- return convertedNamedBlocks;
- }
-
- private Dictionary LoadAndExtractRaw(int blockIndex, IComputationHandler handler)
- {
- // this cannot run concurrently as cache entries can only be read and written once without wasting resources and / or corrupting cache state
- lock (this)
- {
- if (blockIndex >= _lastReadRawDataBlockIndex)
- {
- object[] lastRawData = null;
-
- for (int tempBlockIndex = _lastReadRawDataBlockIndex + 1; tempBlockIndex <= blockIndex; tempBlockIndex++)
- {
- lastRawData = LoadDirect(tempBlockIndex, handler);
-
- //looks like we couldn't read any more blocks, maybe reached the end of the underlying source streams
- if (lastRawData == null)
- {
- return null;
- }
-
- if (AllowRawReadDataCaching)
- {
- _cacheProvider.Store($"raw.{tempBlockIndex}", lastRawData);
- }
- }
-
- return ExtractDirectFrom(lastRawData, blockIndex, handler);
- }
- else
- {
- if (AllowRawReadDataCaching)
- {
- string cacheIdentifier = $"raw.{blockIndex}";
-
- if (!_cacheProvider.IsCached(cacheIdentifier))
- {
- throw new InvalidOperationException($"Unable to load cached entry for block {blockIndex} for handler {handler}, cache entry does not exist in provider {_cacheProvider}.");
- }
-
- return ExtractDirectFrom(_cacheProvider.Load(cacheIdentifier), blockIndex, handler);
- }
- else
- {
- throw new InvalidOperationException($"Cannot load and extract raw block with index {blockIndex} because AllowRawReadDataCaching is set to false and last read position is at {_lastReadRawDataBlockIndex}.");
- }
- }
- }
- }
-
- private object[] LoadDirect(int blockIndex, IComputationHandler handler)
- {
- IList rawDataPerExtractor = new List();
-
- PrepareExtractors();
-
- foreach (IRecordExtractor extractor in _recordExtractors)
- {
- object data;
-
- lock (extractor.Reader)
- {
- data = extractor.Reader.Read(TargetBlockSizeRecords);
- }
-
- //check if block reader could read anything, if not, return null
- if (data == null)
- {
- _lastAvailableBlockIndex = blockIndex - 1;
-
- _logger.Debug($"Cannot load block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. End of stream most likely reached.");
-
- return null;
- }
-
- rawDataPerExtractor.Add(data);
- }
-
- if (blockIndex > _lastReadRawDataBlockIndex)
- {
- _lastReadRawDataBlockIndex = blockIndex;
- }
-
- return rawDataPerExtractor.ToArray();
- }
-
- private Dictionary ExtractDirectFrom(object[] data, int blockIndex, IComputationHandler handler)
- {
- Dictionary namedBlocks = new Dictionary();
-
- ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, "preparing extractors for dataset \"" + Name + "\"", indeterminate: true);
-
- PrepareExtractors();
-
- SigmaEnvironment.TaskManager.EndTask(prepareTask);
-
- ITaskObserver extractTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Extract, $"extracting block {blockIndex} for dataset \"{Name}\"", indeterminate: true);
-
- int extractorIndex = 0;
- foreach (IRecordExtractor extractor in _recordExtractors)
- {
- _logger.Debug($"Extracting hierarchically from extractor {extractor} at index {extractorIndex}...");
-
- Dictionary subNamedBlock = extractor.ExtractHierarchicalFrom(data[extractorIndex++], TargetBlockSizeRecords, handler);
-
- //check if block size is 0, indicating we reached the end of the stream
- if (subNamedBlock == null)
- {
- _lastAvailableBlockIndex = blockIndex - 1;
-
- _logger.Debug($"Cannot extract block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. End of stream most likely reached.");
-
- SigmaEnvironment.TaskManager.CancelTask(extractTask);
-
- return null;
- }
-
- foreach (string name in subNamedBlock.Keys)
- {
- if (namedBlocks.ContainsKey(name))
- {
- SigmaEnvironment.TaskManager.CancelTask(extractTask);
-
- throw new ArgumentException($"Section name collision: {name} is already used by another extractor, current extractor {extractor} cannot use it again.");
- }
- else
- {
- namedBlocks.Add(name, subNamedBlock[name]);
- }
- }
- }
-
- SigmaEnvironment.TaskManager.EndTask(extractTask);
-
- return namedBlocks;
- }
-
- public void FreeBlock(int blockIndex, IComputationHandler handler)
- {
- if (!_activeBlocks.ContainsKey(blockIndex))
- {
- _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active.");
-
- return;
- }
-
- foreach (RecordBlock block in _activeBlocks[blockIndex])
- {
- if (ReferenceEquals(block.Handler, handler))
- {
- _logger.Debug($"Freeing block with index {blockIndex} for handler {handler}...");
-
- CacheBlockConstrained(block.NamedBlockSections, blockIndex, handler);
-
- DeregisterActiveBlock(block);
-
- //_availableBlocksSemaphore.Release();
- //_availableBlocksSemaphoreState++;
-
- _logger.Debug($"Done freeing block with index {blockIndex} for handler {handler}.");
-
- return;
- }
- }
-
- _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active.");
- }
-
- private void CacheBlockConstrained(Dictionary block, int blockIndex, IComputationHandler handler)
- {
- if (_cachedBlocks.ContainsKey(blockIndex))
- {
- foreach (WeakRecordBlock cachedBlock in _cachedBlocks[blockIndex])
- {
- //check if block of the same type and size is already cached, if so, return, because there is no need to cache again
- if (cachedBlock.BlockIndex == blockIndex && cachedBlock.Handler.IsInterchangeable(handler) && block.First().Value.Shape[0] == cachedBlock.NumberRecords)
- {
- _logger.Debug($"Skipping cache request of block {blockIndex} for handler {handler} because interchangeable block of same index, format and size is already cached.");
-
- return;
- }
- }
- }
-
- long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray());
-
- if (_cachedBlocks.Count >= MaxBlocksInCache)
- {
- _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBlocksInCache constraint of {MaxBlocksInCache}.");
-
- return;
- }
-
- if (blockSizeBytes + _totalCachedBlockSizeBytes >= MaxBytesInCache)
- {
- _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBytesInCache constraint of {MaxBytesInCache} bytes (block of size {blockSizeBytes} would exceed constraint by {_totalCachedBlockSizeBytes + blockSizeBytes - MaxBytesInCache} bytes).");
-
- return;
- }
-
- string cacheIdentifier = $"extracted.{blockIndex}.{handler.DataType.Identifier}";
-
- _cacheProvider.Store(cacheIdentifier, block);
-
- bool keepReference = TotalActiveBlockSizeBytes + blockSizeBytes < MaxTotalActiveBlockSizeBytes;
-
- RegisterCachedBlock(block, blockIndex, handler, keepReference);
-
- _totalCachedBlockSizeBytes += blockSizeBytes;
- }
-
- private void PrepareExtractors()
- {
- foreach (IRecordExtractor extractor in _recordExtractors)
- {
- lock (extractor)
- {
- extractor.Prepare();
- }
- }
- }
-
- public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler)
- {
- if (!_activeBlocks.ContainsKey(blockIndex))
- {
- return -1L;
- }
-
- foreach (RecordBlock block in _activeBlocks[blockIndex])
- {
- if (ReferenceEquals(block.Handler, handler))
- {
- return block.EstimatedSizeBytes;
- }
- }
-
- return -1L;
- }
-
- public bool IsBlockActive(int blockIndex)
- {
- return _activeBlocks.ContainsKey(blockIndex);
- }
-
- public bool IsBlockActive(int blockIndex, IComputationHandler handler)
- {
- if (!_activeBlocks.ContainsKey(blockIndex))
- {
- return false;
- }
-
- foreach (RecordBlock block in _activeBlocks[blockIndex])
- {
- if (ReferenceEquals(block.Handler, handler))
- {
- return true;
- }
- }
-
- return false;
- }
-
- private bool IsBlockCached(int blockIndex, IComputationHandler handler)
- {
- if (!_cachedBlocks.ContainsKey(blockIndex))
- {
- return false;
- }
-
- foreach (WeakRecordBlock block in _cachedBlocks[blockIndex])
- {
- if (ReferenceEquals(block.Handler, handler))
- {
- return true;
- }
- }
-
- return false;
- }
-
- public void Dispose()
- {
- foreach (IRecordExtractor extractor in _recordExtractors)
- {
- extractor.Dispose();
- extractor.Reader?.Dispose();
- }
-
- _cacheProvider.Dispose();
- }
-
- public static IDataset[] SplitBlockwise(IDataset dataset, params int[] parts)
- {
- if (parts.Length == 0)
- {
- throw new ArgumentException("Parts cannot be an empty collection.");
- }
-
- int splitInterval = parts.Sum();
- int lastEnd = 0;
- IDataset[] slices = new IDataset[parts.Length];
-
- for (int i = 0; i < parts.Length; i++)
- {
- slices[i] = new DatasetBlockwiseSlice(dataset, lastEnd, lastEnd + parts[i] - 1, splitInterval);
- lastEnd += parts[i];
- }
-
- return slices;
- }
-
- public static IDataset[] SplitRecordwise(IDataset dataset, params double[] parts)
- {
- if (parts.Length == 0)
- {
- throw new ArgumentException("Percentages cannot be an empty collection.");
- }
-
- if (parts.Sum() > 1.0)
- {
- throw new ArgumentException($"Percentages sum cannot be > 1.0, but parts sum was {parts.Sum()}.");
- }
-
- IDataset[] slices = new IDataset[parts.Length];
-
- double lastOffset = 0.0;
-
- for (int i = 0; i < slices.Length; i++)
- {
- slices[i] = new DatasetRecordwiseSlice(dataset, lastOffset, parts[i]);
-
- lastOffset += parts[i];
- }
-
- return slices;
- }
-
- internal abstract class RecordBlockBase
- {
- internal abstract Dictionary NamedBlockSections { get; set; }
- internal abstract INDArray FirstNamedBlock { get; set; }
- internal abstract bool Loaded { get; set; }
-
- internal IComputationHandler Handler;
- internal bool Active;
- internal int BlockIndex;
- internal long NumberRecords;
- internal long EstimatedSizeBytes;
- }
-
- internal class RecordBlock : RecordBlockBase
- {
- internal sealed override Dictionary NamedBlockSections { get; set; }
- internal sealed override INDArray FirstNamedBlock { get; set; }
- internal override bool Loaded { get; set; }
-
- public RecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler)
- {
- NamedBlockSections = namedBlockSections;
- BlockIndex = blockIndex;
- NumberRecords = numberRecords;
- EstimatedSizeBytes = estimatedSizeBytes;
- Handler = handler;
-
- //record blocks internal block can be null
- if (namedBlockSections != null)
- {
- FirstNamedBlock = namedBlockSections[namedBlockSections.First().Key];
- }
- }
- }
-
- internal class WeakRecordBlock : RecordBlockBase
- {
- internal override Dictionary NamedBlockSections
- {
- get
- {
- Dictionary target;
-
- return _namedBlockSections.TryGetTarget(out target) ? target : null;
- }
- set
- {
- _namedBlockSections.SetTarget(value);
- }
- }
-
- internal override INDArray FirstNamedBlock
- {
- get
- {
- INDArray target;
-
- return _firstNamedBlock.TryGetTarget(out target) ? target : null;
- }
- set
- {
- _firstNamedBlock.SetTarget(value);
- }
- }
-
- internal override bool Loaded
- {
- get
- {
- Dictionary target;
-
- return _namedBlockSections.TryGetTarget(out target);
- }
- set
- {
- }
- }
-
- private readonly WeakReference> _namedBlockSections;
- private readonly WeakReference _firstNamedBlock;
-
- public WeakRecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler)
- {
- _namedBlockSections = new WeakReference>(namedBlockSections);
- BlockIndex = blockIndex;
- NumberRecords = numberRecords;
- EstimatedSizeBytes = estimatedSizeBytes;
- Handler = handler;
-
- //record blocks internal block can be null
- if (namedBlockSections != null)
- {
- _firstNamedBlock = new WeakReference(namedBlockSections[namedBlockSections.First().Key]);
- }
- }
- }
-
- public override string ToString()
- {
- return $"dataset \"{Name}\"";
- }
- }
-}
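
For readers following the rename (illustration only, not part of the patch): the Dataset class deleted above lives on as ExtractedDataset, so the static split helpers are reached through the new type name while their behaviour is unchanged. A call site changes roughly like this, with 'dataset' being any IDataset:

// before this patch:
// IDataset[] slices = Dataset.SplitRecordwise(dataset, 0.8, 0.2);
// after this patch (same semantics, renamed type):
IDataset[] slices = ExtractedDataset.SplitRecordwise(dataset, 0.8, 0.2);
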
diff --git a/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs b/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs
index 22797aa7..31509525 100644
--- a/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs
+++ b/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs
@@ -50,7 +50,7 @@ public long MaxBytesInCache
}
public string[] SectionNames => UnderlyingDataset.SectionNames;
public IReadOnlyCollection ActiveBlockIndices => UnderlyingDataset.ActiveBlockIndices;
- public int ActiveIndividualBlockCount => UnderlyingDataset.ActiveIndividualBlockCount;
+ public int ActiveIndividualBlockRegionCount => UnderlyingDataset.ActiveIndividualBlockRegionCount;
public int ActiveBlockRegionCount => UnderlyingDataset.ActiveBlockRegionCount;
///
@@ -112,12 +112,12 @@ protected int MapToUnderlyingIndex(int blockIndex)
public IDataset[] SplitBlockwise(params int[] parts)
{
- return Dataset.SplitBlockwise(this, parts);
+ return ExtractedDataset.SplitBlockwise(this, parts);
}
public IDataset[] SplitRecordwise(params double[] parts)
{
- return Dataset.SplitRecordwise(this, parts);
+ return ExtractedDataset.SplitRecordwise(this, parts);
}
public bool TrySetBlockSize(int blockSizeRecords)
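
To make the blockwise slice semantics concrete (a sketch based on the SplitBlockwise implementation shown in the deleted Dataset above): the parts define a repeating split interval over block indices, so an 8/2 split hands eight of every ten blocks to the first slice and the remaining two to the second. The 'dataset' variable is assumed to be any IDataset.

// split interval = 8 + 2 = 10; slice 0 covers block indices 0..7 of every interval, slice 1 covers 8..9
IDataset[] blockSlices = ExtractedDataset.SplitBlockwise(dataset, 8, 2);
IDataset trainingBlocks = blockSlices[0];
IDataset validationBlocks = blockSlices[1];
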
diff --git a/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs b/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs
index baf2fe84..f8714d6b 100644
--- a/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs
+++ b/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs
@@ -48,7 +48,7 @@ public long MaxBytesInCache
}
public string[] SectionNames => UnderlyingDataset.SectionNames;
public IReadOnlyCollection ActiveBlockIndices => UnderlyingDataset.ActiveBlockIndices;
- public int ActiveIndividualBlockCount => UnderlyingDataset.ActiveIndividualBlockCount;
+ public int ActiveIndividualBlockRegionCount => UnderlyingDataset.ActiveIndividualBlockRegionCount;
public int ActiveBlockRegionCount => UnderlyingDataset.ActiveBlockRegionCount;
///
@@ -89,12 +89,12 @@ public DatasetRecordwiseSlice(IDataset underlyingDataset, double shareOffset, do
public IDataset[] SplitBlockwise(params int[] parts)
{
- return Dataset.SplitBlockwise(this, parts);
+ return ExtractedDataset.SplitBlockwise(this, parts);
}
public IDataset[] SplitRecordwise(params double[] parts)
{
- return Dataset.SplitRecordwise(this, parts);
+ return ExtractedDataset.SplitRecordwise(this, parts);
}
public bool TrySetBlockSize(int blockSizeRecords)
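
As a usage sketch for the new ExtractedDataset class introduced below (illustrative, not part of the patch): construction matches the deleted Dataset class, so callers mostly only switch the type name. The record extractor and computation handler are assumed to be set up elsewhere, for example a CSV extractor and the trainer's handler.

// 'extractor' is any IRecordExtractor prepared elsewhere; BlockSizeAuto sizes blocks from available memory.
IDataset dataset = new ExtractedDataset("iris", ExtractedDataset.BlockSizeAuto, extractor);

// Fetch (and activate) the first extracted record block for the given handler, then release it again.
IDictionary<string, INDArray> block = dataset.FetchBlock(0, handler);
dataset.FreeBlock(0, handler);
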
diff --git a/Sigma.Core/Data/Datasets/ExtractedDataset.cs b/Sigma.Core/Data/Datasets/ExtractedDataset.cs
new file mode 100644
index 00000000..9ca598a9
--- /dev/null
+++ b/Sigma.Core/Data/Datasets/ExtractedDataset.cs
@@ -0,0 +1,1069 @@
+/*
+MIT License
+
+Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer
+
+For full license see LICENSE in the root directory of this project.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using log4net;
+using Sigma.Core.Data.Extractors;
+using Sigma.Core.Handlers;
+using Sigma.Core.MathAbstract;
+using Sigma.Core.Persistence;
+using Sigma.Core.Utils;
+
+namespace Sigma.Core.Data.Datasets
+{
+ ///
+ /// A default implementation of the IDataset interface.
+ /// Provides caching of entire blocks and reader data, partial extraction, unordered extraction, automatic block sizing, smart block loading.
+ ///
+ [Serializable]
+ public class ExtractedDataset : IDataset, ISerialisationNotifier
+ {
+ [NonSerialized]
+ private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
+
+ ///
+ /// Automatically size blocks according to estimated data metrics (e.g. physical memory available, record size).
+ ///
+ public const int BlockSizeAuto = -1;
+
+ ///
+ /// Assign all available data to the first block (one block fits it all - literally).
+ ///
+ public const int BlockSizeAll = -2;
+
+ ///
+ public string Name { get; }
+
+ ///
+ /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime).
+ /// By default, this is assumed to be false, indicating a static dataset.
+ /// Note: Data iterators may perform certain optimisations for static datasets, so set this to false if possible.
+ ///
+ public bool Online { get; set; } = false;
+
+ ///
+ public int MaxConcurrentActiveBlocks { get; } = 24; //24 seems like a good number, right?
+
+ ///
+ public long MaxTotalActiveBlockSizeBytes { get; } = SystemInformationUtils.GetAvailablePhysicalMemoryBytes() / 2; //default to half the available physical memory
+
+ ///
+ public IReadOnlyCollection<int> ActiveBlockIndices => _activeBlocks.Keys.ToList();
+
+ ///
+ public int ActiveBlockRegionCount => _activeBlocks.Count;
+
+ ///
+ public int ActiveIndividualBlockRegionCount { get { return _activeBlocks.Values.Sum(set => set.Count); } }
+
+ ///
+ public int TargetBlockSizeRecords { get; private set; }
+
+ ///
+ public string[] SectionNames { get; private set; }
+
+ ///
+ public long TotalActiveBlockSizeBytes { get; private set; }
+
+ public long TotalActiveRecords { get; private set; }
+
+ ///
+ public int MaxBlocksInCache { get; set; } = int.MaxValue;
+
+ ///
+ public long MaxBytesInCache { get; set; } = long.MaxValue;
+
+ ///
+ /// Indicate whether this dataset should cache the raw reader data.
+ /// If disabled, only extracted data will be cached and once processed, it might be impossible to retrieve preceding record blocks (reader streams are assumed to be non-seekable).
+ ///
+ public bool AllowRawReadDataCaching { get; set; } = true;
+
+ private readonly Dictionary<int, ISet<RecordBlock>> _activeBlocks;
+ private readonly Dictionary<int, ISet<WeakRecordBlock>> _cachedBlocks;
+ private readonly ICacheProvider _cacheProvider;
+
+ private int _lastReadRawDataBlockIndex = -1;
+ private long _totalCachedBlockSizeBytes;
+ private int _lastAvailableBlockIndex = int.MaxValue;
+ private readonly ISet<IRecordExtractor> _recordExtractors;
+
+ private readonly bool _autoSetBlockSize;
+ private bool _autoSetExternalChangeBlockSize;
+
+ // TODO fix available blocks semaphore logic
+ // the waitones/releases are inconsistent, because blocks aren't always actually allocated, such as null returns are not considered
+ [NonSerialized]
+ private Semaphore _availableBlocksSemaphore;
+ private int _availableBlocksSemaphoreState;
+
+ ///
+ /// Create a dataset with a certain unique name and the record extractors to use.
+ ///
+ /// The unique dataset name.
+ /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
+ public ExtractedDataset(string name, params IRecordExtractor[] recordExtractors) : this(name, BlockSizeAuto, recordExtractors)
+ {
+ }
+
+ ///
+ /// Create a dataset with a certain unique name, target block size in records and the record extractors to use.
+ ///
+ /// The unique dataset name.
+ /// The target block size for records. May also be BlockSizeAuto or BlockSizeAll.
+ /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
+ public ExtractedDataset(string name, int blockSizeRecords, params IRecordExtractor[] recordExtractors)
+ : this(name, blockSizeRecords, true, recordExtractors)
+ {
+ }
+
+ ///
+ /// Create a dataset with a certain unique name, target block size in records and the record extractors to use.
+ ///
+ /// The unique dataset name.
+ /// The target block size for records. May also be BlockSizeAuto or BlockSizeAll.
+ /// Indicate whether the cache provider should be flushed (cleared) before use. Only disable if block size and extractors used do not change (otherwise undefined behaviour).
+ /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
+ public ExtractedDataset(string name, int blockSizeRecords, bool flushCache, params IRecordExtractor[] recordExtractors)
+ : this(name, blockSizeRecords, new DiskCacheProvider(SigmaEnvironment.Globals.Get<string>("cache_path") + name), flushCache, recordExtractors)
+ {
+ }
+
+ ///
+ /// Create a dataset with a certain unique name, target block size in records, specific cache provider and the record extractors to use.
+ ///
+ /// The unique dataset name.
+ /// The target block size for records. May also be BlockSizeAuto or BlockSizeAll.
+ /// The cache provider to use for caching record blocks and raw reader data.
+ /// Indicate whether the cache provider should be flushed (cleared) before use. Only disable if block size and extractors used do not change (otherwise undefined behaviour).
+ /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks.
+ public ExtractedDataset(string name, int blockSizeRecords, ICacheProvider cacheProvider, bool flushCache = true, params IRecordExtractor[] recordExtractors)
+ {
+ if (name == null)
+ {
+ throw new ArgumentNullException(nameof(name));
+ }
+
+ if (recordExtractors == null)
+ {
+ throw new ArgumentNullException(nameof(recordExtractors));
+ }
+
+ if (recordExtractors.Length == 0)
+ {
+ throw new ArgumentException("Datasets require at least one record extractor, but none were given.");
+ }
+
+ if (cacheProvider == null)
+ {
+ throw new ArgumentNullException(nameof(cacheProvider));
+ }
+
+ switch (blockSizeRecords)
+ {
+ case BlockSizeAll:
+ //just set to maximum amount of records, extracting returns the maximum available anyway and we can't know the actual availability yet
+ TargetBlockSizeRecords = int.MaxValue;
+ break;
+ case BlockSizeAuto:
+ //somewhat temporary guesstimate, should probably expose the individual parameters
+ const long estimatedRecordSizeBytes = 1024;
+ const double memoryToConsume = 0.2f;
+ const long optimalNumberBlocks = 8;
+ const int maxBlockSizeRecords = 4096;
+ long availableSystemMemory = SystemInformationUtils.GetAvailablePhysicalMemoryBytes();
+
+ TargetBlockSizeRecords = Math.Min(maxBlockSizeRecords, (int)(availableSystemMemory * memoryToConsume / estimatedRecordSizeBytes / optimalNumberBlocks));
+
+ _autoSetBlockSize = true;
+ break;
+ default:
+ if (blockSizeRecords == 0 || blockSizeRecords < -2)
+ {
+ throw new ArgumentException($"Block size in records must be either BLOCK_SIZE_ALL, BLOCK_SIZE_AUTO or > 0, but given block size was {blockSizeRecords}.");
+ }
+ else
+ {
+ TargetBlockSizeRecords = blockSizeRecords;
+ }
+ break;
+ }
+
+ Name = name;
+ AnalyseExtractors(recordExtractors);
+
+ _cacheProvider = cacheProvider;
+ _recordExtractors = new HashSet<IRecordExtractor>(recordExtractors);
+
+ _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks, MaxConcurrentActiveBlocks);
+ _availableBlocksSemaphoreState = MaxConcurrentActiveBlocks;
+
+ _activeBlocks = new Dictionary<int, ISet<RecordBlock>>();
+ _cachedBlocks = new Dictionary<int, ISet<WeakRecordBlock>>();
+
+ if (flushCache)
+ {
+ _logger.Debug($"Flushing all caches for dataset \"{Name}\" as flushCache flag was set...");
+
+ InvalidateAndClearCaches();
+
+ _logger.Debug($"Done flushing all caches for dataset \"{Name}.\"");
+ }
+ }
+
+ ///
+ /// Called before this object is serialised.
+ ///
+ public void OnSerialising()
+ {
+ }
+
+ ///
+ /// Called after this object was serialised.
+ ///
+ public void OnSerialised()
+ {
+ }
+
+ ///
+ /// Called after this object was de-serialised.
+ ///
+ public void OnDeserialised()
+ {
+ InvalidateAndClearCaches();
+ _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks - ActiveIndividualBlockRegionCount, MaxConcurrentActiveBlocks);
+ }
+
+ public IDataset[] SplitBlockwise(params int[] parts)
+ {
+ return SplitBlockwise(this, parts);
+ }
+
+ public IDataset[] SplitRecordwise(params double[] parts)
+ {
+ return SplitRecordwise(this, parts);
+ }
+
+ public bool TrySetBlockSize(int blockSizeRecords)
+ {
+ if (blockSizeRecords == TargetBlockSizeRecords)
+ {
+ //nothing to do here
+ return true;
+ }
+
+ if (!_autoSetBlockSize)
+ {
+ _logger.Debug($"Cannot change block size as block size was not set automatically (attempted to change block size to {blockSizeRecords}.");
+
+ return false;
+ }
+
+ if (_activeBlocks.Count > 0 || _cachedBlocks.Count > 0)
+ {
+ _logger.Debug($"Cannot change block size as {_activeBlocks.Count + _cachedBlocks.Count} blocks were already fetched and are active or cached.");
+
+ return false;
+ }
+
+ if (_autoSetExternalChangeBlockSize && blockSizeRecords != TargetBlockSizeRecords)
+ {
+ _logger.Debug($"Cannot change block size to {blockSizeRecords}, block size is incompatible with another external block size change request (other request: {TargetBlockSizeRecords})");
+
+ return false;
+ }
+
+ _autoSetExternalChangeBlockSize = true;
+ TargetBlockSizeRecords = blockSizeRecords;
+
+ return true;
+ }
+
+ private void AnalyseExtractors(IEnumerable<IRecordExtractor> extractors)
+ {
+ ISet<string> sectionNames = new HashSet<string>();
+
+ int index = 0;
+ foreach (IRecordExtractor extractor in extractors)
+ {
+ if (extractor == null)
+ {
+ throw new ArgumentNullException($"Extractor at index {index} was null.");
+ }
+
+ if (extractor.SectionNames == null)
+ {
+ throw new ArgumentNullException($"Section names field in extractor {extractor} was null (field has to be set by extractor).");
+ }
+
+ string[] extractorSectionNames = extractor.SectionNames;
+
+ foreach (string sectionName in extractorSectionNames)
+ {
+ if (sectionNames.Contains(sectionName))
+ {
+ throw new ArgumentException($"Section name collision: duplicate section name {sectionName} detected for extractor {extractor}.");
+ }
+ else
+ {
+ sectionNames.Add(sectionName);
+ }
+ }
+
+ index++;
+ }
+
+ SectionNames = sectionNames.ToArray();
+ }
+
+ public int GetNumberOfLoadedInactiveCachedBlocks()
+ {
+ return _cachedBlocks.Values.SelectMany(blockSet => blockSet).Count(block => block.Loaded);
+ }
+
+ public bool CanFetchBlocksAfter(int blockIndex)
+ {
+ return blockIndex <= _lastAvailableBlockIndex;
+ }
+
+ public async Task<IDictionary<string, INDArray>> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
+ {
+ //TODO check if block even could be fetched to not waste thread resources if shouldWaitUntilAvailable is false anyway
+
+ return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable));
+ }
+
+ public IDictionary<string, INDArray> FetchBlock(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
+ {
+ Dictionary<string, INDArray> block = FetchBlockConstrained(blockIndex, handler);
+
+ //block could be fetched directly without violating any constraints, return successfully
+ if (block != null)
+ {
+ if (block.Count == 0)
+ {
+ throw new InvalidOperationException("Fetched block did not contain any named elements (was empty; is the extractor output correct?).");
+ }
+
+ RegisterActiveBlock(block, blockIndex, handler);
+
+ return block;
+ }
+ else
+ {
+ if (blockIndex >= _lastAvailableBlockIndex)
+ {
+ return null;
+ }
+
+ if (shouldWaitUntilAvailable)
+ {
+ _logger.Debug($"Could not directly load block with index {blockIndex} for handler {handler} and shouldWaitUntilAvailable flag is set to true, waiting for available space...");
+
+ return FetchBlockWhenAvailable(blockIndex, handler);
+ }
+ else
+ {
+ return null;
+ }
+ }
+ }
+
+ private void RegisterActiveBlock(Dictionary<string, INDArray> block, int blockIndex, IComputationHandler handler)
+ {
+ INDArray firstNamedBlock = block[block.First().Key];
+
+ if (IsBlockActive(blockIndex, handler))
+ {
+ //block already registered as active, nothing to do here
+ return;
+ }
+
+ RecordBlock recordBlock = new RecordBlock(block, blockIndex, firstNamedBlock.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler)
+ { Loaded = true, Active = true };
+
+ lock (_activeBlocks)
+ {
+ TotalActiveBlockSizeBytes += recordBlock.EstimatedSizeBytes;
+ TotalActiveRecords += recordBlock.NumberRecords;
+
+ if (!_activeBlocks.ContainsKey(blockIndex))
+ {
+ _activeBlocks.Add(blockIndex, new HashSet<RecordBlock>());
+ }
+
+ _activeBlocks[blockIndex].Add(recordBlock);
+ }
+ }
+
+ private void DeregisterActiveBlock(RecordBlock recordBlock)
+ {
+ if (!IsBlockActive(recordBlock.BlockIndex, recordBlock.Handler))
+ {
+ //block that should be de-registered is not even registered
+ return;
+ }
+
+ lock (this)
+ {
+ TotalActiveBlockSizeBytes -= recordBlock.EstimatedSizeBytes;
+ TotalActiveRecords -= recordBlock.NumberRecords;
+
+ _activeBlocks[recordBlock.BlockIndex].Remove(recordBlock);
+
+ if (_activeBlocks[recordBlock.BlockIndex].Count == 0)
+ {
+ _activeBlocks.Remove(recordBlock.BlockIndex);
+ }
+ }
+ }
+
+ private void RegisterCachedBlock(Dictionary<string, INDArray> block, int blockIndex, IComputationHandler handler, bool keepReference)
+ {
+ if (IsBlockCached(blockIndex, handler))
+ {
+ //block's already cached, nothing to do here
+ return;
+ }
+
+ if (!_cachedBlocks.ContainsKey(blockIndex))
+ {
+ _cachedBlocks.Add(blockIndex, new HashSet<WeakRecordBlock>());
+ }
+
+ WeakRecordBlock recordBlock = new WeakRecordBlock(keepReference ? block : null, blockIndex, block.First().Value.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler);
+
+ recordBlock.Loaded = false;
+
+ _cachedBlocks[blockIndex].Add(recordBlock);
+ }
+
+ ///
+ /// Invalidate and clear all caches associated with this dataset.
+ /// WARNING: Removing cache entries may cause certain datasets to load much more slowly or even incorrectly.
+ /// Legitimate use cases include removing cache entries for old datasets or changing extractors.
+ ///
+ public void InvalidateAndClearCaches()
+ {
+ _logger.Debug("Invalidating and clearing all caches...");
+
+ _cacheProvider.RemoveAll();
+ _cachedBlocks.Clear();
+ _totalCachedBlockSizeBytes = 0L;
+
+ _logger.Debug("Done invalidating and clearing all caches.");
+ }
+
+ private Dictionary<string, INDArray> FetchBlockWhenAvailable(int blockIndex, IComputationHandler handler)
+ {
+ while (true)
+ {
+ _logger.Debug($"Attempting to extract block region for request for block index {blockIndex} for handler {handler}, checking if it fits all constraints...");
+
+ Dictionary block = FetchBlockConstrained(blockIndex, handler);
+
+ //if block != null we could fetch the block successfully without violating any constraints
+ if (block != null)
+ {
+ RegisterActiveBlock(block, blockIndex, handler);
+
+ return block;
+ }
+ else
+ {
+ //we cannot retrieve any more blocks and shouldn't keep trying
+ if (blockIndex >= _lastAvailableBlockIndex)
+ {
+ return null;
+ }
+
+ _logger.Debug($"Request for block with index {blockIndex} for handler {handler} was returned to the queue, seems to be violating constraints...");
+ }
+ }
+ }
+
+ private Dictionary FetchBlockConstrained(int blockIndex, IComputationHandler handler)
+ {
+ if (ActiveIndividualBlockRegionCount >= MaxConcurrentActiveBlocks)
+ {
+ _logger.Debug($"Unable to fetch block due to MaxConcurrentActiveBlocks constraint of {MaxConcurrentActiveBlocks}.");
+
+ return null;
+ }
+
+ Dictionary block = LoadAndExtractBlockWhenAvailable(blockIndex, handler);
+
+ //there was nothing to load and extract, most likely end of stream
+ if (block == null)
+ {
+ return null;
+ }
+
+ long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray());
+
+ if (TotalActiveBlockSizeBytes + blockSizeBytes > MaxTotalActiveBlockSizeBytes)
+ {
+ _logger.Debug($"Unable to keep requested block {blockIndex} for handler {handler} in memory due to MaxTotalActiveBlockSizeBytes constraint of {MaxTotalActiveBlockSizeBytes} bytes (block of size {blockSizeBytes} would exceed constraint by {TotalActiveBlockSizeBytes + blockSizeBytes - MaxTotalActiveBlockSizeBytes} bytes.).");
+
+ CacheBlockConstrained(block, blockIndex, handler);
+
+ return null;
+ }
+
+ return block;
+ }
+
+ private Dictionary LoadAndExtractBlockWhenAvailable(int blockIndex, IComputationHandler handler)
+ {
+ //this method takes care of
+ // - checking whether the block index is already loaded and active, then converting it
+ // - or checking whether the block index is already cached in the right format, then loading it
+ // - or, if neither applies, loading and extracting from the original extractors
+
+ //check whether a block with the same index and format is already active
+ if (_activeBlocks.ContainsKey(blockIndex))
+ {
+ Dictionary block = GetBestMatchedBlockWhenAvailable(_activeBlocks[blockIndex], handler);
+
+ if (block != null)
+ {
+ return block;
+ }
+ }
+
+ //check whether a block with the same index and format is already loaded and cached but not active
+ if (_cachedBlocks.ContainsKey(blockIndex))
+ {
+ Dictionary block = GetBestMatchedBlockWhenAvailable(_cachedBlocks[blockIndex], handler);
+
+ if (block != null)
+ {
+ return block;
+ }
+ }
+
+ lock (_cacheProvider)
+ {
+ string blockIdentifierInCache = $"extracted.{blockIndex}.{handler.DataType.Identifier}";
+
+ //check whether a block of the same index and format is cached in the cache provider
+ if (_cacheProvider.IsCached(blockIdentifierInCache))
+ {
+ Dictionary block = _cacheProvider.Load>(blockIdentifierInCache);
+
+ //if its != null we could read it correctly in the right format
+ if (block != null)
+ {
+ //register this cache entry as a properly loaded block in case the cache wasn't flushed and the cache map is outdated
+ RegisterCachedBlock(block, blockIndex, handler, keepReference: false);
+
+ return block;
+ }
+ }
+ }
+
+ //_availableBlocksSemaphore.WaitOne();
+ //_availableBlocksSemaphoreState--;
+
+ return LoadAndExtractRaw(blockIndex, handler);
+ }
+
+ private Dictionary GetBestMatchedBlockWhenAvailable(IEnumerable blocks, IComputationHandler handler)
+ {
+ RecordBlockBase bestMatchedBlock = null;
+
+ foreach (RecordBlockBase otherBlock in blocks)
+ {
+ if (otherBlock.Loaded && handler.CanConvert(otherBlock.FirstNamedBlock, otherBlock.Handler))
+ {
+ if (handler.IsInterchangeable(otherBlock.Handler))
+ {
+ //no need to look any further, we already found the perfect match and can return without conversion
+ return otherBlock.NamedBlockSections;
+ }
+
+ bestMatchedBlock = otherBlock;
+ }
+ }
+
+ if (bestMatchedBlock == null)
+ {
+ return null;
+ }
+
+ //_availableBlocksSemaphore.WaitOne();
+ //_availableBlocksSemaphoreState--;
+
+ return ConvertNamedBlocks(bestMatchedBlock.NamedBlockSections, handler);
+ }
+
+ private static Dictionary ConvertNamedBlocks(Dictionary namedBlockSections, IComputationHandler handler)
+ {
+ Dictionary convertedNamedBlocks = new Dictionary();
+
+ foreach (string name in namedBlockSections.Keys)
+ {
+ convertedNamedBlocks.Add(name, handler.Convert(namedBlockSections[name], handler));
+ }
+
+ return convertedNamedBlocks;
+ }
+
+ private Dictionary LoadAndExtractRaw(int blockIndex, IComputationHandler handler)
+ {
+ // this cannot run concurrently; cache entries can only be read and written once each, otherwise resources are wasted and / or the cache state is corrupted
+ lock (this)
+ {
+ if (blockIndex > _lastReadRawDataBlockIndex)
+ {
+ object[] lastRawData = null;
+
+ for (int tempBlockIndex = _lastReadRawDataBlockIndex + 1; tempBlockIndex <= blockIndex; tempBlockIndex++)
+ {
+ lastRawData = LoadDirect(tempBlockIndex, handler);
+
+ //looks like we couldn't read any more blocks, maybe reached the end of the underlying source streams
+ if (lastRawData == null)
+ {
+ return null;
+ }
+
+ if (AllowRawReadDataCaching)
+ {
+ _cacheProvider.Store($"raw.{tempBlockIndex}", lastRawData);
+ }
+ }
+
+ return ExtractDirectFrom(lastRawData, blockIndex, handler);
+ }
+ else
+ {
+ if (AllowRawReadDataCaching)
+ {
+ string cacheIdentifier = $"raw.{blockIndex}";
+
+ if (!_cacheProvider.IsCached(cacheIdentifier))
+ {
+ throw new InvalidOperationException($"Unable to load cached entry for block {blockIndex} for handler {handler}, cache entry does not exist in provider {_cacheProvider}.");
+ }
+
+ return ExtractDirectFrom(_cacheProvider.Load(cacheIdentifier), blockIndex, handler);
+ }
+ else
+ {
+ throw new InvalidOperationException($"Cannot load and extract raw block with index {blockIndex} because AllowRawReadDataCaching is set to false and last read position is at {_lastReadRawDataBlockIndex}.");
+ }
+ }
+ }
+ }
+
+ private object[] LoadDirect(int blockIndex, IComputationHandler handler)
+ {
+ IList rawDataPerExtractor = new List();
+
+ PrepareExtractors();
+
+ foreach (IRecordExtractor extractor in _recordExtractors)
+ {
+ object data;
+
+ lock (extractor.Reader)
+ {
+ data = extractor.Reader.Read(TargetBlockSizeRecords);
+ }
+
+ //check if block reader could read anything, if not, return null
+ if (data == null)
+ {
+ _lastAvailableBlockIndex = blockIndex - 1;
+
+ _logger.Debug($"Cannot load block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. End of stream most likely reached.");
+
+ return null;
+ }
+
+ rawDataPerExtractor.Add(data);
+ }
+
+ if (blockIndex > _lastReadRawDataBlockIndex)
+ {
+ _lastReadRawDataBlockIndex = blockIndex;
+ }
+
+ return rawDataPerExtractor.ToArray();
+ }
+
+ private Dictionary ExtractDirectFrom(object[] data, int blockIndex, IComputationHandler handler)
+ {
+ Dictionary namedBlocks = new Dictionary();
+
+ ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, "preparing extractors for dataset \"" + Name + "\"", indeterminate: true);
+
+ PrepareExtractors();
+
+ SigmaEnvironment.TaskManager.EndTask(prepareTask);
+
+ ITaskObserver extractTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Extract, $"extracting block {blockIndex} for dataset \"{Name}\"", indeterminate: true);
+
+ int extractorIndex = 0;
+ foreach (IRecordExtractor extractor in _recordExtractors)
+ {
+ _logger.Debug($"Extracting hierarchically from extractor {extractor} at index {extractorIndex}...");
+
+ Dictionary subNamedBlock = extractor.ExtractHierarchicalFrom(data[extractorIndex++], TargetBlockSizeRecords, handler);
+
+ //check if nothing could be extracted, indicating we reached the end of the stream
+ if (subNamedBlock == null)
+ {
+ _lastAvailableBlockIndex = blockIndex - 1;
+
+ _logger.Debug($"Cannot extract block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. End of stream most likely reached.");
+
+ SigmaEnvironment.TaskManager.CancelTask(extractTask);
+
+ return null;
+ }
+
+ foreach (string name in subNamedBlock.Keys)
+ {
+ if (namedBlocks.ContainsKey(name))
+ {
+ SigmaEnvironment.TaskManager.CancelTask(extractTask);
+
+ throw new ArgumentException($"Section name collision: {name} is already used by another extractor, current extractor {extractor} cannot use it again.");
+ }
+ else
+ {
+ namedBlocks.Add(name, subNamedBlock[name]);
+ }
+ }
+ }
+
+ SigmaEnvironment.TaskManager.EndTask(extractTask);
+
+ return namedBlocks;
+ }
+
+ public void FreeBlock(int blockIndex, IComputationHandler handler)
+ {
+ lock (_activeBlocks)
+ {
+ if (!_activeBlocks.ContainsKey(blockIndex))
+ {
+ _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active.");
+
+ return;
+ }
+
+ RecordBlock toRemove = null;
+
+ foreach (RecordBlock block in _activeBlocks[blockIndex])
+ {
+ if (ReferenceEquals(block.Handler, handler))
+ {
+ _logger.Debug($"Freeing block with index {blockIndex} for handler {handler}...");
+
+ CacheBlockConstrained(block.NamedBlockSections, blockIndex, handler);
+
+ //_availableBlocksSemaphore.Release();
+ //_availableBlocksSemaphoreState++;
+
+ toRemove = block;
+
+ goto FoundBlock;
+ }
+ }
+
+ _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active.");
+
+ FoundBlock:
+
+ DeregisterActiveBlock(toRemove);
+ _logger.Debug($"Done freeing block with index {blockIndex} for handler {handler}.");
+
+ }
+ }
+
+ private void CacheBlockConstrained(Dictionary block, int blockIndex, IComputationHandler handler)
+ {
+ if (_cachedBlocks.ContainsKey(blockIndex))
+ {
+ foreach (WeakRecordBlock cachedBlock in _cachedBlocks[blockIndex])
+ {
+ //check if block of the same type and size is already cached, if so, return, because there is no need to cache again
+ if (cachedBlock.BlockIndex == blockIndex && cachedBlock.Handler.IsInterchangeable(handler) && block.First().Value.Shape[0] == cachedBlock.NumberRecords)
+ {
+ _logger.Debug($"Skipping cache request of block {blockIndex} for handler {handler} because interchangeable block of same index, format and size is already cached.");
+
+ return;
+ }
+ }
+ }
+
+ long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray());
+
+ if (_cachedBlocks.Count >= MaxBlocksInCache)
+ {
+ _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBlocksInCache constraint of {MaxBlocksInCache}.");
+
+ return;
+ }
+
+ if (blockSizeBytes + _totalCachedBlockSizeBytes >= MaxBytesInCache)
+ {
+ _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBytesInCache constraint of {MaxBytesInCache} bytes (block of size {blockSizeBytes} would exceed constraint by {_totalCachedBlockSizeBytes + blockSizeBytes - MaxBytesInCache} bytes).");
+
+ return;
+ }
+
+ string cacheIdentifier = $"extracted.{blockIndex}.{handler.DataType.Identifier}";
+
+ _cacheProvider.Store(cacheIdentifier, block);
+
+ bool keepReference = TotalActiveBlockSizeBytes + blockSizeBytes < MaxTotalActiveBlockSizeBytes;
+
+ RegisterCachedBlock(block, blockIndex, handler, keepReference);
+
+ _totalCachedBlockSizeBytes += blockSizeBytes;
+ }
+
+ private void PrepareExtractors()
+ {
+ foreach (IRecordExtractor extractor in _recordExtractors)
+ {
+ lock (extractor)
+ {
+ extractor.Prepare();
+ }
+ }
+ }
+
+ public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler)
+ {
+ if (!_activeBlocks.ContainsKey(blockIndex))
+ {
+ return -1L;
+ }
+
+ foreach (RecordBlock block in _activeBlocks[blockIndex])
+ {
+ if (ReferenceEquals(block.Handler, handler))
+ {
+ return block.EstimatedSizeBytes;
+ }
+ }
+
+ return -1L;
+ }
+
+ public bool IsBlockActive(int blockIndex)
+ {
+ return _activeBlocks.ContainsKey(blockIndex);
+ }
+
+ public bool IsBlockActive(int blockIndex, IComputationHandler handler)
+ {
+ if (!_activeBlocks.ContainsKey(blockIndex))
+ {
+ return false;
+ }
+
+ foreach (RecordBlock block in _activeBlocks[blockIndex])
+ {
+ if (ReferenceEquals(block.Handler, handler))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ private bool IsBlockCached(int blockIndex, IComputationHandler handler)
+ {
+ if (!_cachedBlocks.ContainsKey(blockIndex))
+ {
+ return false;
+ }
+
+ foreach (WeakRecordBlock block in _cachedBlocks[blockIndex])
+ {
+ if (ReferenceEquals(block.Handler, handler))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ public void Dispose()
+ {
+ foreach (IRecordExtractor extractor in _recordExtractors)
+ {
+ extractor.Dispose();
+ extractor.Reader?.Dispose();
+ }
+
+ _cacheProvider.Dispose();
+ }
+
+ public static IDataset[] SplitBlockwise(IDataset dataset, params int[] parts)
+ {
+ if (parts.Length == 0)
+ {
+ throw new ArgumentException("Parts cannot be an empty collection.");
+ }
+
+ int splitInterval = parts.Sum();
+ int lastEnd = 0;
+ IDataset[] slices = new IDataset[parts.Length];
+
+ for (int i = 0; i < parts.Length; i++)
+ {
+ slices[i] = new DatasetBlockwiseSlice(dataset, lastEnd, lastEnd + parts[i] - 1, splitInterval);
+ lastEnd += parts[i];
+ }
+
+ return slices;
+ }
+
+ public static IDataset[] SplitRecordwise(IDataset dataset, params double[] parts)
+ {
+ if (parts.Length == 0)
+ {
+ throw new ArgumentException("Percentages cannot be an empty collection.");
+ }
+
+ if (parts.Sum() > 1.0)
+ {
+ throw new ArgumentException($"Percentages sum cannot be > 1.0, but parts sum was {parts.Sum()}.");
+ }
+
+ IDataset[] slices = new IDataset[parts.Length];
+
+ double lastOffset = 0.0;
+
+ for (int i = 0; i < slices.Length; i++)
+ {
+ slices[i] = new DatasetRecordwiseSlice(dataset, lastOffset, parts[i]);
+
+ lastOffset += parts[i];
+ }
+
+ return slices;
+ }
+
+ internal abstract class RecordBlockBase
+ {
+ internal abstract Dictionary NamedBlockSections { get; set; }
+ internal abstract INDArray FirstNamedBlock { get; set; }
+ internal abstract bool Loaded { get; set; }
+
+ internal IComputationHandler Handler;
+ internal bool Active;
+ internal int BlockIndex;
+ internal long NumberRecords;
+ internal long EstimatedSizeBytes;
+ }
+
+ internal class RecordBlock : RecordBlockBase
+ {
+ internal sealed override Dictionary NamedBlockSections { get; set; }
+ internal sealed override INDArray FirstNamedBlock { get; set; }
+ internal override bool Loaded { get; set; }
+
+ public RecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler)
+ {
+ NamedBlockSections = namedBlockSections;
+ BlockIndex = blockIndex;
+ NumberRecords = numberRecords;
+ EstimatedSizeBytes = estimatedSizeBytes;
+ Handler = handler;
+
+ //the record block's internal named block sections can be null
+ if (namedBlockSections != null)
+ {
+ FirstNamedBlock = namedBlockSections[namedBlockSections.First().Key];
+ }
+ }
+ }
+
+ internal class WeakRecordBlock : RecordBlockBase
+ {
+ internal override Dictionary NamedBlockSections
+ {
+ get
+ {
+ Dictionary target;
+
+ return _namedBlockSections.TryGetTarget(out target) ? target : null;
+ }
+ set
+ {
+ _namedBlockSections.SetTarget(value);
+ }
+ }
+
+ internal override INDArray FirstNamedBlock
+ {
+ get
+ {
+ INDArray target;
+
+ return _firstNamedBlock.TryGetTarget(out target) ? target : null;
+ }
+ set
+ {
+ _firstNamedBlock.SetTarget(value);
+ }
+ }
+
+ internal override bool Loaded
+ {
+ get
+ {
+ Dictionary target;
+
+ return _namedBlockSections.TryGetTarget(out target);
+ }
+ set
+ {
+ }
+ }
+
+ private readonly WeakReference> _namedBlockSections;
+ private readonly WeakReference _firstNamedBlock;
+
+ public WeakRecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler)
+ {
+ _namedBlockSections = new WeakReference>(namedBlockSections);
+ BlockIndex = blockIndex;
+ NumberRecords = numberRecords;
+ EstimatedSizeBytes = estimatedSizeBytes;
+ Handler = handler;
+
+ //the record block's internal named block sections can be null
+ if (namedBlockSections != null)
+ {
+ _firstNamedBlock = new WeakReference(namedBlockSections[namedBlockSections.First().Key]);
+ }
+ }
+ }
+
+ public override string ToString()
+ {
+ return $"dataset \"{Name}\"";
+ }
+ }
+}
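For illustration, a minimal sketch of the block lifecycle implemented above, assuming an already constructed IDataset (the construction helper below is hypothetical); only FetchBlock, GetBlockSizeBytes and FreeBlock from the dataset interface are used:

    IComputationHandler handler = new CpuFloat32Handler();
    IDataset dataset = CreateDataset(); // hypothetical helper, stands in for any ExtractedDataset setup

    // with shouldWaitUntilAvailable = true this call blocks until the active-block constraints allow the fetch
    var block = dataset.FetchBlock(0, handler, shouldWaitUntilAvailable: true);

    if (block != null)
    {
        long sizeBytes = dataset.GetBlockSizeBytes(0, handler); // -1 if the block is not active for this handler

        // ... consume the named sections, e.g. block["inputs"] and block["targets"] ...

        dataset.FreeBlock(0, handler); // release the block so it can be cached and its memory reclaimed
    }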
diff --git a/Sigma.Core/Data/Datasets/IDataset.cs b/Sigma.Core/Data/Datasets/IDataset.cs
index a4ffc1c2..558b9fbc 100644
--- a/Sigma.Core/Data/Datasets/IDataset.cs
+++ b/Sigma.Core/Data/Datasets/IDataset.cs
@@ -76,7 +76,7 @@ public interface IDataset : IDisposable
///
/// The number of currently active and loaded record blocks, with different block formats counting as different blocks.
///
- int ActiveIndividualBlockCount { get; }
+ int ActiveIndividualBlockRegionCount { get; }
///
/// The number of currently active and loaded record blocks, with different block formats of the same region counting as one active block index.
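The blockwise / recordwise split helpers declared on this interface (and implemented as static helpers in ExtractedDataset above) are what typically carve a dataset into training and validation slices. A short sketch, assuming dataset is any IDataset instance and using illustrative ratios:

    IDataset[] slices = dataset.SplitRecordwise(0.8, 0.2); // fractions must sum to at most 1.0
    IDataset trainingData = slices[0];
    IDataset validationData = slices[1];

    // alternatively, split by whole blocks in a repeating 2:1:1 interleaving pattern
    IDataset[] blockSlices = dataset.SplitBlockwise(2, 1, 1);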
diff --git a/Sigma.Core/Data/Datasets/RawDataset.cs b/Sigma.Core/Data/Datasets/RawDataset.cs
new file mode 100644
index 00000000..5ac3e388
--- /dev/null
+++ b/Sigma.Core/Data/Datasets/RawDataset.cs
@@ -0,0 +1,335 @@
+/*
+MIT License
+
+Copyright(c) 2016-2017 Florian Cäsar, Michael Plainer
+
+For full license see LICENSE in the root directory of this project.
+*/
+
+
+using System;
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Sigma.Core.Handlers;
+using Sigma.Core.Handlers.Backends.SigmaDiff.NativeCpu;
+using Sigma.Core.MathAbstract;
+using Sigma.Core.Utils;
+
+namespace Sigma.Core.Data.Datasets
+{
+ ///
+ /// A raw in-system-memory dataset which can be filled and modified manually at runtime.
+ ///
+ [Serializable]
+ public class RawDataset : IDataset
+ {
+ private readonly IComputationHandler _internalHandler;
+
+ ///
+ /// The name and identifier of this dataset.
+ /// Dataset names should be globally unique and easily identifiable.
+ ///
+ public string Name { get; }
+
+ ///
+ /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime).
+ /// By default, this is assumed to be false, indicating a static dataset.
+ /// Note: Data iterators may perform certain optimisations for static datasets, so set this to false if possible.
+ ///
+ public bool Online { get; set; }
+
+ ///
+ /// The preferred per block size in records.
+ /// Note: Not every block must obey this request (e.g. the last block might very well be a different size).
+ ///
+ public int TargetBlockSizeRecords { get; }
+
+ ///
+ /// The maximum number of concurrently active blocks.
+ ///
+ public int MaxConcurrentActiveBlocks { get; }
+
+ ///
+ /// The maximum total concurrently active block size in bytes.
+ ///
+ public long MaxTotalActiveBlockSizeBytes { get; }
+
+ ///
+ /// The total size of all currently active record blocks in system memory in bytes.
+ ///
+ public long TotalActiveBlockSizeBytes { get; }
+
+ ///
+ /// The maximum number of blocks to keep in the cache (inactive blocks are written to a cache, typically on disk, to be reloaded later).
+ ///
+ public int MaxBlocksInCache { get { throw new NotSupportedException(); } set { throw new NotSupportedException(); } }
+
+ ///
+ /// The maximum number of bytes to keep in the cache (inactive blocks are written to a cache, typically on disk, to be reloaded later).
+ ///
+ public long MaxBytesInCache { get { throw new NotSupportedException(); } set { throw new NotSupportedException(); } }
+
+ ///
+ /// The names for all sections present in this dataset (e.g. "inputs", "targets").
+ ///
+ public string[] SectionNames { get; }
+
+ ///
+ /// A set of currently active and loaded record block indices.
+ ///
+ public IReadOnlyCollection ActiveBlockIndices { get; }
+
+ ///
+ /// The number of currently active and loaded record blocks, with different block formats counting as different blocks.
+ ///
+ public int ActiveIndividualBlockRegionCount => _rawData.Count;
+
+ ///
+ /// The number of currently active and loaded record blocks, with different block formats of the same region counting as one active block index.
+ ///
+ public int ActiveBlockRegionCount => _rawData.Count > 0 ? 1 : 0; // there is at most 1 active block region (block index 0)
+
+ ///
+ /// The current working data that can be edited and is "flushed" to the public raw data with the next FetchBlock call.
+ ///
+ private readonly IDictionary _internalWorkingData;
+
+ ///
+ /// The raw data of this dataset that is returned via the FetchBlock functions.
+ ///
+ private readonly IDictionary> _rawData;
+
+ ///
+ /// Create a raw dataset with a certain name and an internal cpu handler with 32-bit float precision.
+ ///
+ /// The globally unique name of this dataset.
+ public RawDataset(string name) : this(name, new CpuFloat32Handler())
+ {
+ }
+
+ ///
+ /// Create a raw dataset with a certain name and computation handler.
+ ///
+ /// The globally unique name of this dataset.
+ /// The internal handler to use for data management.
+ public RawDataset(string name, IComputationHandler internalHandler)
+ {
+ if (name == null) throw new ArgumentNullException(nameof(name));
+ if (internalHandler == null) throw new ArgumentNullException(nameof(internalHandler));
+
+ Name = name;
+ _internalHandler = internalHandler;
+
+ _internalWorkingData = new ConcurrentDictionary();
+ _rawData = new ConcurrentDictionary>();
+ }
+
+ ///
+ /// Add a record to a certain block.
+ /// Note: Feature shape is length of record, no time dimension, auto batch dimension.
+ ///
+ /// The record data type (must be primitive data type).
+ /// The block name (e.g. "inputs").
+ /// The record.
+ public void AddRecord(string blockName, params T[] record)
+ {
+ AddShapedRecords(blockName, new long[] { record.Length }, new[] { record });
+ }
+
+ ///
+ /// Add a record with a certain feature shape to a certain block.
+ /// Note: Feature shape is as specified, no time dimension, auto batch dimension.
+ ///
+ /// The record data type (must be primitive data type).
+ /// The block name (e.g. "inputs").
+ /// The feature shape.
+ /// The record.
+ public void AddRecord(string blockName, long[] featureShape, params T[] record)
+ {
+ AddShapedRecords(blockName, featureShape, record);
+ }
+
+ ///
+ /// Add records to a certain block.
+ /// Note: Feature shape is length of record, no time dimension, auto batch dimension.
+ ///
+ /// The record data type (must be primitive data type).
+ /// The block name (e.g. "inputs").
+ /// The records.
+ public void AddRecords(string blockName, params T[][] records)
+ {
+ AddShapedRecords(blockName, new long[] { records[0].Length }, records);
+ }
+
+ ///
+ /// Add records with a certain feature shape to a certain block.
+ /// Note: Feature shape is as specified, no time dimension, auto batch dimension.
+ ///
+ /// The record data type (must be primitive data type).
+ /// The block name (e.g. "inputs").
+ /// The feature shape.
+ /// The records.
+ public void AddShapedRecords(string blockName, long[] featureShape, params T[][] records)
+ {
+ if (records.Length == 0)
+ {
+ return;
+ }
+
+ long featureLength = ArrayUtils.Product(featureShape);
+ long[] newShape = ArrayUtils.Concatenate(new long[] { records.Length, 1 }, featureShape); // BatchTimeFeatures shape order, time dimension is not supported at the moment
+ long[] insertedShape = (long[])newShape.Clone();
+ bool previousBlockExists = _internalWorkingData.ContainsKey(blockName);
+
+ if (previousBlockExists)
+ {
+ newShape[0] += _internalWorkingData[blockName].Shape[0]; // append new records to the end of the existing block
+ }
+
+ INDArray newBlock = _internalHandler.NDArray(newShape);
+
+ long[] destinationBegin = new long[newBlock.Rank];
+
+ if (previousBlockExists)
+ {
+ INDArray oldBlock = _internalWorkingData[blockName];
+
+ for (int i = 1; i < oldBlock.Shape.Length; i++)
+ {
+ if (newShape[i] != oldBlock.Shape[i])
+ {
+ throw new InvalidOperationException($"Shape mismatch: already existing block for \"{blockName}\" has shape {ArrayUtils.ToString(oldBlock.Shape)} but new block has shape {ArrayUtils.ToString(newShape)}");
+ }
+ }
+
+ long[] previousSourceBegin = new long[oldBlock.Rank];
+ long[] previousSourceEnd = oldBlock.Shape.Select(i => i - 1).ToArray();
+
+ _internalHandler.Fill(oldBlock, newBlock, previousSourceBegin, previousSourceEnd, previousSourceBegin, previousSourceEnd);
+
+ destinationBegin[0] = oldBlock.Shape[0];
+ }
+
+ long[] destinationEnd = insertedShape.Select(i => i - 1).ToArray();
+ destinationEnd[0] = destinationBegin[0];
+
+ for (int i = 0; i < records.Length; i++)
+ {
+ _internalHandler.Fill(records[i], newBlock, destinationBegin, destinationEnd);
+
+ destinationBegin[0]++;
+ destinationEnd[0]++;
+ }
+
+ if (previousBlockExists)
+ {
+ _internalWorkingData[blockName] = newBlock;
+ }
+ else
+ {
+ _internalWorkingData.Add(blockName, newBlock);
+ }
+ }
+
+ ///
+ public IDictionary FetchBlock(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
+ {
+ if (blockIndex != 0 || _internalWorkingData.Count == 0)
+ {
+ return null; // there is only 1 block in this raw dataset implementation (so if there's no data, there's no block)
+ }
+
+ if (IsBlockActive(blockIndex, handler))
+ {
+ return _rawData[handler];
+ }
+
+ if (!handler.CanConvert(_internalWorkingData.Values.First(), _internalHandler))
+ {
+ return null;
+ }
+
+ IDictionary convertedBlock = new Dictionary();
+
+ foreach (string blockName in _internalWorkingData.Keys)
+ {
+ convertedBlock[blockName] = handler.Convert(_internalWorkingData[blockName], _internalHandler);
+ }
+
+ _rawData.Add(handler, convertedBlock);
+
+ return convertedBlock;
+ }
+
+ ///
+ public async Task> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true)
+ {
+ return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable));
+ }
+
+ ///
+ public void FreeBlock(int blockIndex, IComputationHandler handler)
+ {
+ if (IsBlockActive(blockIndex, handler))
+ {
+ _rawData.Remove(handler);
+ }
+ }
+
+ ///
+ public bool IsBlockActive(int blockIndex)
+ {
+ return blockIndex == 0;
+ }
+
+ ///
+ public bool IsBlockActive(int blockIndex, IComputationHandler handler)
+ {
+ return blockIndex == 0 && _rawData.ContainsKey(handler);
+ }
+
+ ///
+ public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler)
+ {
+ if (!IsBlockActive(blockIndex, handler))
+ {
+ return -1L;
+ }
+
+ // TODO
+ throw new NotImplementedException();
+ }
+
+ ///
+ public bool CanFetchBlocksAfter(int blockIndex)
+ {
+ return blockIndex == -1;
+ }
+
+ ///
+ public bool TrySetBlockSize(int blockSizeRecords)
+ {
+ throw new NotSupportedException();
+ }
+
+ ///
+ public IDataset[] SplitBlockwise(params int[] parts)
+ {
+ return ExtractedDataset.SplitBlockwise(this, parts);
+ }
+
+ ///
+ public IDataset[] SplitRecordwise(params double[] percentages)
+ {
+ return ExtractedDataset.SplitRecordwise(this, percentages);
+ }
+
+ /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+ public void Dispose()
+ {
+ }
+ }
+}
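A short usage sketch of the new RawDataset, with illustrative section names and values; the shapes follow the BatchTimeFeatures layout described in AddShapedRecords:

    RawDataset dataset = new RawDataset("xor");

    // each call appends records to the named block; the feature shape defaults to the record length
    dataset.AddRecords("inputs", new[] { 0f, 0f }, new[] { 0f, 1f }, new[] { 1f, 0f }, new[] { 1f, 1f });
    dataset.AddRecords("targets", new[] { 0f }, new[] { 1f }, new[] { 1f }, new[] { 0f });

    // a raw dataset only ever exposes block index 0, converted to the requesting handler's format
    IComputationHandler handler = new CpuFloat32Handler();
    var block = dataset.FetchBlock(0, handler);

    // block["inputs"] now has shape [4, 1, 2] (batch, time, features), block["targets"] has shape [4, 1, 1]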
diff --git a/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs b/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs
index f6b57adb..de71a009 100644
--- a/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs
+++ b/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs
@@ -28,12 +28,22 @@ public abstract class BaseAdaptivePreprocessor : BasePreprocessor
private readonly TPreprocessor _underlyingPreprocessor;
private bool _initialAdaptionComplete;
+ ///
+ /// Create a base adaptive preprocessor with a certain underlying preprocessor (that will be adapted).
+ ///
+ /// The underlying preprocessor.
+ /// The section names to process in this preprocessor (all if null or empty).
+ protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, params string[] sectionNames) : this(underlyingPreprocessor, AdaptionRate.Initial, sectionNames)
+ {
+ }
+
///
/// Create a base adaptive preprocessor with a certain underlying preprocessor (that will be adapted).
///
/// The underlying preprocessor.
/// The adaption rate.
- protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, AdaptionRate adaptionRate = AdaptionRate.Every)
+ /// The section names to process in this preprocessor (all if null or empty).
+ protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, AdaptionRate adaptionRate, params string[] sectionNames) : base(sectionNames)
{
if (underlyingPreprocessor == null) throw new ArgumentNullException(nameof(underlyingPreprocessor));
if (!Enum.IsDefined(typeof(AdaptionRate), adaptionRate)) throw new InvalidEnumArgumentException(nameof(adaptionRate), (int) adaptionRate, typeof(AdaptionRate));
diff --git a/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs b/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs
new file mode 100644
index 00000000..b2f84fbf
--- /dev/null
+++ b/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs
@@ -0,0 +1,75 @@
+using System;
+using ManagedCuda.VectorTypes;
+using Sigma.Core.Handlers;
+using Sigma.Core.MathAbstract;
+
+namespace Sigma.Core.Data.Preprocessors
+{
+ [Serializable]
+ public class ShufflePreprocessor : BasePreprocessor
+ {
+ ///
+ /// The dimension along which records should be shuffled.
+ /// Note: Must be >= 0.
+ ///
+ public int AlongDimension
+ {
+ get { return _alongDimension; }
+ set
+ {
+ if (value < 0) throw new ArgumentException($"Along dimension must be >= 0 but given value was {value}.");
+
+ _alongDimension = value;
+ }
+ }
+
+ ///
+ public override bool AffectsDataShape => false;
+
+ private Random _random;
+ private int _alongDimension;
+
+ ///
+ /// Create a shuffle preprocessor and optionally specify the dimension along which records are shuffled (batch dimension (0) by default).
+ ///
+ ///
+ public ShufflePreprocessor(int alongDimension = 0)
+ {
+ AlongDimension = alongDimension;
+ }
+
+ ///
+ /// Process a certain ndarray with a certain computation handler.
+ ///
+ /// The ndarray to process.
+ /// The computation handler to do the processing with.
+ /// An ndarray with the processed contents of the given array (can be the same or a new one).
+ internal override INDArray ProcessDirect(INDArray array, IComputationHandler handler)
+ {
+ int recordLength = (int) (array.Length / array.Shape[0]);
+ long[] firstBufferIndices = new long[array.Shape.Length];
+ long[] secondBufferIndices = new long[array.Shape.Length];
+
+ _random = new Random(31415926); // fixed rng for reproducibility
+
+ for (int i = 0; i < array.Shape[0]; i++)
+ {
+ int swapIndex = _random.Next((int) array.Shape[0]);
+
+ for (int y = 0; y < recordLength; y++)
+ {
+ NDArrayUtils.GetIndices(recordLength * i + y, array.Shape, array.Strides, firstBufferIndices);
+ NDArrayUtils.GetIndices(recordLength * swapIndex + y, array.Shape, array.Strides, secondBufferIndices);
+
+ double firstValue = array.GetValue(firstBufferIndices);
+ double secondValue = array.GetValue(secondBufferIndices);
+
+ array.SetValue(secondValue, firstBufferIndices);
+ array.SetValue(firstValue, secondBufferIndices);
+ }
+ }
+
+ return array;
+ }
+ }
+}
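The swap performed by ProcessDirect is easier to follow on a plain jagged array; the following standalone sketch mirrors the same record-level exchange (illustration only, not Sigma API):

    float[][] records = { new[] { 1f, 2f }, new[] { 3f, 4f }, new[] { 5f, 6f } };
    Random random = new Random(31415926); // fixed seed, mirroring the reproducibility choice above

    for (int i = 0; i < records.Length; i++)
    {
        int swapIndex = random.Next(records.Length);

        // swap record i with a randomly chosen record, as ProcessDirect does value-by-value per record
        float[] temp = records[i];
        records[i] = records[swapIndex];
        records[swapIndex] = temp;
    }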
diff --git a/Sigma.Core/Dependencies/DiffSharp.dll b/Sigma.Core/Dependencies/DiffSharp.dll
index ee9e31cf..0e868b5c 100644
Binary files a/Sigma.Core/Dependencies/DiffSharp.dll and b/Sigma.Core/Dependencies/DiffSharp.dll differ
diff --git a/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs b/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs
index 8453f69c..b2dc1ed3 100644
--- a/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs
+++ b/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs
@@ -76,8 +76,8 @@ public DebugHandler(IComputationHandler underlyingHandler, bool throwExceptionOn
// kind of ugly but saves me from writing more solid property handling
ThrowExceptionOnReport = throwExceptionOnReport;
Enabled = enabled;
- CheckNaN = false;
- CheckInfinite = false;
+ CheckNaN = enabled;
+ CheckInfinite = enabled;
}
private void Report(string message, params object[] values)
@@ -99,23 +99,23 @@ private INDArray CheckNice(INDArray array, string paramName = "unspecified")
if (array == null)
{
- Report($"ndarray {paramName} is null.");
+ Report($"ndarray \"{paramName}\" is null.");
}
else
{
if (array.Rank != array.Shape.Length)
{
- Report($"ndarray {paramName} has inconsistent rank ({array.Rank}) / shape (length {array.Length}).", array);
+ Report($"ndarray \"{paramName}\" has inconsistent rank ({array.Rank}) / shape (length {array.Length}).", array);
}
- if (CheckNaN && !UnderlyingHandler.IsNaN(array))
+ if (CheckNaN && UnderlyingHandler.IsNaN(array))
{
- Report($"ndarray {paramName} contains NaN values.", array);
+ Report($"ndarray \"{paramName}\" contains NaN values.", array);
}
- if (CheckInfinite && !UnderlyingHandler.IsNotFinite(array))
+ if (CheckInfinite && UnderlyingHandler.IsNotFinite(array))
{
- Report($"ndarray {paramName} contains infinite values.", array);
+ Report($"ndarray \"{paramName}\" contains infinite values.", array);
}
}
@@ -135,12 +135,12 @@ private INumber CheckNice(INumber number, string paramName = "unspecified")
}
else
{
- if (CheckNaN && !UnderlyingHandler.IsNaN(number))
+ if (CheckNaN && UnderlyingHandler.IsNaN(number))
{
Report($"number {paramName} is a NaN value.", number);
}
- if (CheckInfinite && !UnderlyingHandler.IsNotFinite(number))
+ if (CheckInfinite && UnderlyingHandler.IsNotFinite(number))
{
Report($"number {paramName} is an infinite value.", number);
}
@@ -278,11 +278,25 @@ public void Fill(INDArray filler, INDArray arrayToFill)
public void Fill(TOther value, INDArray arrayToFill)
{
- UnderlyingHandler.Fill(value, CheckNice(arrayToFill));
+ UnderlyingHandler.Fill(value, arrayToFill);
CheckNice(arrayToFill);
}
+ public void Fill(INDArray filler, INDArray arrayToFill, long[] sourceBeginIndices, long[] sourceEndIndices, long[] destinationBeginIndices, long[] destinationEndIndices)
+ {
+ UnderlyingHandler.Fill(CheckNice(filler), CheckNice(arrayToFill), sourceBeginIndices, sourceEndIndices, destinationBeginIndices, destinationEndIndices);
+
+ CheckNice(arrayToFill);
+ }
+
+ public void Fill(T[] filler, INDArray arrayToFill, long[] destinationBeginIndices, long[] destinationEndIndices)
+ {
+ UnderlyingHandler.Fill(filler, CheckNice(arrayToFill), destinationBeginIndices, destinationEndIndices);
+
+ CheckNice(arrayToFill);
+ }
+
public INDArray FlattenTime(INDArray array)
{
if (Enabled && array.Rank < 2) // two or three? technically 2 is enough ([BT]F) but 3 makes more sense
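With the corrected checks above, a DebugHandler reports NaN / infinite values when they actually occur rather than when they do not. A minimal wrapping sketch (constructor defaults are not visible in this diff, so both flags are passed explicitly):

    IComputationHandler underlying = new CpuFloat32Handler();
    IComputationHandler handler = new DebugHandler(underlying, throwExceptionOnReport: true, enabled: true);

    // every operation routed through 'handler' is forwarded to 'underlying',
    // and each involved ndarray / number is checked for NaN and infinite values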
diff --git a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs
index d91d8eb8..4a9e14ab 100644
--- a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs
+++ b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs
@@ -53,9 +53,7 @@ internal DiffSharpBackendHandle(IBlasBackend blasBackend, ILapackBackend lapackB
public abstract FSharpOption> Solve_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b);
public abstract FSharpOption> SolveSymmetric_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b);
public abstract ISigmaDiffDataBuffer Diagonal_M(ShapedDataBufferView a);
- public abstract ISigmaDiffDataBuffer Map_F_V(FSharpFunc a, ISigmaDiffDataBuffer b);
- public abstract ISigmaDiffDataBuffer Map2_F_V_V(FSharpFunc> a, ISigmaDiffDataBuffer b, ISigmaDiffDataBuffer obj2);
- public abstract ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBufferView value);
+ public abstract ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBufferView value);
public abstract ShapedDataBufferView Mul_Out_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b);
public abstract ShapedDataBufferView Add_M_M(ShapedDataBufferView a, ShapedDataBufferView b);
public abstract ShapedDataBufferView Add_S_M(T a, ShapedDataBufferView b);
@@ -70,10 +68,14 @@ internal DiffSharpBackendHandle(IBlasBackend blasBackend, ILapackBackend lapackB
public abstract FSharpOption> Inverse_M(ShapedDataBufferView a);
public abstract FSharpOption Det_M(ShapedDataBufferView a);
public abstract ShapedDataBufferView Transpose_M(ShapedDataBufferView a);
- public abstract ShapedDataBufferView Map_F_M(FSharpFunc a, ShapedDataBufferView b);
- public abstract ShapedDataBufferView Map2_F_M_M(FSharpFunc> a, ShapedDataBufferView b, ShapedDataBufferView obj2);
public abstract ShapedDataBufferView ReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value);
public abstract ShapedDataBufferView RepeatReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value);
public abstract ShapedDataBufferView RepeatReshapeCopy_V_MCols(int cols, ISigmaDiffDataBuffer value);
+ public abstract ISigmaDiffDataBuffer Map_F_V(MapOp mapOp, FSharpFunc function, ISigmaDiffDataBuffer value);
+ public abstract ISigmaDiffDataBuffer Map_F_S_V(T other, MapOp mapOp, FSharpFunc function, ISigmaDiffDataBuffer value);
+ public abstract ISigmaDiffDataBuffer Map2_F_V_V(MapOp mapOp, FSharpFunc> function, ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b);
+ public abstract ShapedDataBufferView Map_F_M(MapOp mapOp, FSharpFunc function, ShapedDataBufferView