Skip to content

Commit

Permalink
Correct access to generated variables in Python
Browse files Browse the repository at this point in the history
A new test specific for Node related functionality was added.

Re ECFLOW-1968
  • Loading branch information
marcosbento committed Jul 22, 2024
1 parent de56499 commit cb874be
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
1 change: 1 addition & 0 deletions libs/pyext/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ set(u_tests
py_u_test_defs_constructor
py_u_test_get_attr
py_u_test_manual
py_u_test_node
py_u_test_late
py_u_test_replace_node
py_u_test_tutorial
Expand Down
12 changes: 11 additions & 1 deletion libs/pyext/src/ecflow/python/ExportNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,16 @@ node_ptr add_defstatus1(node_ptr self, const Defstatus& ds) {
return self;
}

bp::list generated_variables(node_ptr self) {
bp::list list;
std::vector<Variable> vec;
self->gen_variables(vec);
for (const auto& i : vec) {
list.append(i);
}
return list;
}

/////////////////////////////////////////////////////////////////////////////////////////

static object do_rshift(node_ptr self, const bp::object& arg) {
Expand Down Expand Up @@ -677,7 +687,7 @@ void export_Node() {
.def("has_time_dependencies", &Node::hasTimeDependencies)
.def("update_generated_variables", &Node::update_generated_variables)
.def("get_generated_variables",
&Node::gen_variables,
&generated_variables,
"returns a list of generated variables. Use ecflow.VariableList as return argument")
.def("is_suspended", &Node::isSuspended, "Returns true if the `node`_ is in a `suspended`_ state")
.def("find_variable",
Expand Down
80 changes: 80 additions & 0 deletions libs/pyext/test/py_u_test_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#
# Copyright 2009- ECMWF.
#
# This software is licensed under the terms of the Apache Licence version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import ecflow_test_util as ect
import ecflow as ecf
import unittest
import sys
import os


class TestNode(unittest.TestCase):

def setUp(self):
self.defs = ecf.Defs()
self.suite = self.defs.add_suite("suite");

self.suite.add_variable("VARIABLE", "value");

self.family = self.suite.add_family("family")
self.family.add_repeat(ecf.RepeatDate("REPEAT", 20010101, 20010102, 1))

self.task = self.family.add_task("task")

def test_is_able_to_retrieve_suite_generated_variables(self):
vars = self.suite.get_generated_variables()

names = [v.name() for v in vars]
self.assertIn("SUITE", names)
self.assertIn("ECF_DATE", names)
self.assertIn("ECF_CLOCK", names)
self.assertIn("ECF_TIME", names)
self.assertIn("ECF_JULIAN", names)

self.assertIn("YYYY", names)
self.assertIn("DD", names)
self.assertIn("MM", names)
self.assertIn("DAY", names)
self.assertIn("MONTH", names)
self.assertIn("DATE", names)
self.assertIn("TIME", names)
self.assertIn("DOW", names)
self.assertIn("DOY", names)

def test_is_able_to_retrieve_family_generated_variables(self):
vars = self.family.get_generated_variables()

names = [v.name() for v in vars]
self.assertIn("FAMILY", names)
self.assertIn("FAMILY1", names)
self.assertIn("REPEAT", names)
self.assertIn("REPEAT_YYYY", names)
self.assertIn("REPEAT_MM", names)
self.assertIn("REPEAT_DD", names)
self.assertIn("REPEAT_DOW", names)
self.assertIn("REPEAT_JULIAN", names)

def test_is_able_to_retrieve_task_generated_variables(self):
vars = self.task.get_generated_variables()

names = [v.name() for v in vars]
self.assertIn("TASK", names)
self.assertIn("ECF_JOB", names)
self.assertIn("ECF_SCRIPT", names)
self.assertIn("ECF_JOBOUT", names)
self.assertIn("ECF_TRYNO", names)
self.assertIn("ECF_RID", names)
self.assertIn("ECF_PASS", names)
self.assertIn("ECF_NAME", names)


if __name__ == "__main__":
unittest.main()
print("All Tests pass")

0 comments on commit cb874be

Please sign in to comment.