Commit dbfbbb2b authored by Jelte Jansen's avatar Jelte Jansen

support for list indices wrt set/add/remove operations


git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac405@3663 e5f2f494-b856-4b98-b285-d166d9295462
parent 96a12cba
......@@ -553,9 +553,6 @@ class BindCmdInterpreter(Cmd):
if cmd.command == "show":
values = self.config_data.get_value_maps(identifier)
print("[XX] VALUE MAPS:")
print(str(values))
print("[XX] END VALUE MAPS")
for value_map in values:
line = value_map['name']
if value_map['type'] in [ 'module', 'map', 'list' ]:
......
......@@ -21,6 +21,7 @@
#
import json
import re
class DataNotFoundError(Exception): pass
class DataTypeError(Exception): pass
......@@ -55,13 +56,67 @@ def merge(orig, new):
else:
orig[kn] = new[kn]
def _split_identifier(identifier):
def _concat_identifier(id_parts):
"""Concatenates the given identifier parts into a string,
delimited with the '/' character.
"""
return '/'.join(id_parts)
def split_identifier(identifier):
"""Splits the given identifier into a list of identifier parts,
as delimited by the '/' character.
Raises a DataTypeError if identifier is not a string."""
if type(identifier) != str:
raise DataTypeError("identifier is not a string")
id_parts = identifier.split("/")
id_parts = identifier.split('/')
id_parts[:] = (value for value in id_parts if value != "")
return id_parts
def split_identifier_list_indices(identifier):
"""Finds list indexes in the given identifier, which are of the
format [integer].
Identifier must be a string.
This will only give the list index for the last 'part' of the
given identifier (as delimited by the '/' sign).
Raises a DataTypeError if the identifier is not a string,
or if the format is bad.
Returns a tuple, where the first element is the string part of
the identifier, and the second element is a list of (nested) list
indices.
Examples:
'a/b/c' will return ('a/b/c', None)
'a/b/c[1]' will return ('a/b/c', [1])
'a/b/c[1][2][3]' will return ('a/b/c', [1, 2, 3])
'a[0]/b[1]/c[2]' will return ('a[0]/b[1]/c, [2])
"""
if type(identifier) != str:
raise DataTypeError("identifier in split_identifier_list_indices() contains '/': " + str(identifier))
id_parts = split_identifier(identifier)
id_str = id_parts[-1]
i = id_str.find('[')
if i < 0:
if identifier.find(']') >= 0:
raise DataTypeError("Bad format in identifier: " + str(identifier))
return identifier, None
id = identifier[:i]
indices = []
while i >= 0:
e = id_str.find(']')
if e < i + 1:
raise DataTypeError("Bad format in identifier: " + str(identifier))
try:
indices.append(int(id_str[i+1:e]))
except ValueError:
raise DataTypeError("List index in " + identifier + " not an integer")
id_str = id_str[e + 1:]
i = id_str.find('[')
if id.find(']') >= 0:
raise DataTypeError("Bad format in identifier: " + str(identifier))
id_parts = id_parts[:-1]
id_parts.append(id)
id = _concat_identifier(id_parts)
return id, indices
def _find_child_el(element, id):
"""Finds the child of element with the given id. If the id contains
[i], where i is a number, and the child element is a list, the
......@@ -71,33 +126,23 @@ def _find_child_el(element, id):
Raises a DataNotFoundError if the element at id could not be
found.
"""
i = id.find('[')
e = id.find(']')
list_index = None
if i >= 0 and e > i + 1:
try:
list_index = int(id[i + 1:e])
except ValueError as ve:
# repack as datatypeerror
raise DataTypeError(ve)
id = id[:i]
id, list_indices = split_identifier_list_indices(id)
if type(element) == dict and id in element.keys():
result = element[id]
else:
raise DataNotFoundError(id + " in " + str(element))
if type(result) == list and list_index is not None:
print("[XX] GETTING ELEMENT NUMBER " + str(list_index) + " (of " + str(len(result)) + ")")
if list_index >= len(result):
print("[XX] OUT OF RANGE")
raise DataNotFoundError("Element " + str(list_index) + " in " + str(result))
result = result[list_index]
if type(result) == list and list_indices is not None:
for list_index in list_indices:
if list_index >= len(result):
raise DataNotFoundError("Element " + str(list_index) + " in " + str(result))
result = result[list_index]
return result
def find(element, identifier):
"""Returns the subelement in the given data element, raises DataNotFoundError if not found"""
if (type(element) != dict and identifier != ""):
raise DataTypeError("element in find() is not a dict")
id_parts = _split_identifier(identifier)
id_parts = split_identifier(identifier)
cur_el = element
for id in id_parts:
cur_el = _find_child_el(cur_el, id)
......@@ -113,17 +158,14 @@ def set(element, identifier, value):
el.set().set().set() is possible)"""
if type(element) != dict:
raise DataTypeError("element in set() is not a dict")
print("[XX] full identifier: " + identifier)
id_parts = _split_identifier(identifier)
if type(identifier) != str:
raise DataTypeError("identifier in set() is not a str")
id_parts = split_identifier(identifier)
cur_el = element
print("[XX] Full element:")
print(element)
for id in id_parts[:-1]:
try:
print("[XX] find " + id)
cur_el = _find_child_el(cur_el, id)
except DataNotFoundError:
print("[XX] DNF for " + id)
if value is None:
# ok we are unsetting a value that wasn't set in
# the first place. Simply stop.
......@@ -131,12 +173,32 @@ def set(element, identifier, value):
cur_el[id] = {}
cur_el = cur_el[id]
# value can be an empty list or dict, so check for None eplicitely
print("[XX] Current value: " + str(cur_el))
if value is not None:
cur_el[id_parts[-1]] = value
elif id_parts[-1] in cur_el:
del cur_el[id_parts[-1]]
id, list_indices = split_identifier_list_indices(id_parts[-1])
if list_indices is None:
# value can be an empty list or dict, so check for None eplicitely
if value is not None:
cur_el[id] = value
else:
del cur_el[id]
else:
cur_el = cur_el[id]
# in case of nested lists, we need to get to the next to last
for list_index in list_indices[:-1]:
if type(cur_el) != list:
raise DataTypeError("Element at " + identifier + " is not a list")
if len(cur_el) <= list_index:
raise DataNotFoundError("List index at " + identifier + " out of range")
cur_el = cur_el[list_index]
# value can be an empty list or dict, so check for None eplicitely
list_index = list_indices[-1]
if type(cur_el) != list:
raise DataTypeError("Element at " + identifier + " is not a list")
if len(cur_el) <= list_index:
raise DataNotFoundError("List index at " + identifier + " out of range")
if value is not None:
cur_el[list_index] = value
else:
del cur_el[list_index]
return element
def unset(element, identifier):
......
......@@ -82,6 +82,23 @@ class TestData(unittest.TestCase):
self.assertRaises(data.DataTypeError, data.merge, 1, d2)
self.assertRaises(data.DataTypeError, data.merge, None, None)
def testsplit_identifier_list_indices(self):
id, indices = data.split_identifier_list_indices('a')
self.assertEqual(id, 'a')
self.assertEqual(indices, None)
id, indices = data.split_identifier_list_indices('a[0]')
self.assertEqual(id, 'a')
self.assertEqual(indices, [0])
id, indices = data.split_identifier_list_indices('a[0][1]')
self.assertEqual(id, 'a')
self.assertEqual(indices, [0, 1])
# bad formats
self.assertRaises(data.DataTypeError, data.split_identifier_list_indices, 'a[')
self.assertRaises(data.DataTypeError, data.split_identifier_list_indices, 'a]')
self.assertRaises(data.DataTypeError, data.split_identifier_list_indices, 'a[[0]]')
def test_find(self):
d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2, 'more': { 'data': 'here' } } }
self.assertEqual(data.find(d1, ''), d1)
......@@ -110,13 +127,27 @@ class TestData(unittest.TestCase):
def test_set(self):
d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2 } }
d12 = { 'b': 1, 'c': { 'e': 3, 'f': [ 1 ] } }
d13 = { 'b': 1, 'c': { 'e': 3, 'f': [ 2 ] } }
d14 = { 'b': 1, 'c': { 'e': 3, 'f': [ { 'g': [ 1, 2 ] } ] } }
d15 = { 'b': 1, 'c': { 'e': 3, 'f': [ { 'g': [ 1, 3 ] } ] } }
data.set(d1, 'a', None)
data.set(d1, 'c/d', None)
data.set(d1, 'c/e/', 3)
data.set(d1, 'c/f', [ 1 ] )
self.assertEqual(d1, d12)
data.set(d1, 'c/f[0]', 2 )
self.assertEqual(d1, d13)
data.set(d1, 'c/f[0]', { 'g': [ 1, 2] } )
self.assertEqual(d1, d14)
data.set(d1, 'c/f[0]/g[1]', 3)
self.assertEqual(d1, d15)
self.assertRaises(data.DataTypeError, data.set, d1, 1, 2)
self.assertRaises(data.DataTypeError, data.set, 1, "", 2)
self.assertRaises(data.DataTypeError, data.set, d1, 'c[1]', 2)
self.assertRaises(data.DataNotFoundError, data.set, d1, 'c/f[5]', 2)
d3 = {}
e3 = data.set(d3, "does/not/exist", 123)
self.assertEqual(d3,
......@@ -125,11 +156,25 @@ class TestData(unittest.TestCase):
{ 'does': { 'not': { 'exist': 123 } } })
def test_unset(self):
d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2 } }
d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': [ 1, 2, 3 ] } }
data.unset(d1, 'a')
data.unset(d1, 'c/d')
data.unset(d1, 'does/not/exist')
self.assertEqual(d1, { 'b': 1, 'c': { 'e': 2 } })
self.assertEqual(d1, { 'b': 1, 'c': { 'e': [ 1, 2, 3 ] } })
data.unset(d1, 'c/e[0]')
self.assertEqual(d1, { 'b': 1, 'c': { 'e': [ 2, 3 ] } })
data.unset(d1, 'c/e[1]')
self.assertEqual(d1, { 'b': 1, 'c': { 'e': [ 2 ] } })
# index 1 should now be out of range
self.assertRaises(data.DataNotFoundError, data.unset, d1, 'c/e[1]')
d2 = { 'a': [ { 'b': [ 1, 2 ] } ] }
data.unset(d2, 'a[0]/b[1]')
self.assertEqual(d2, { 'a': [ { 'b': [ 1 ] } ] })
d3 = { 'a': [ [ 1, 2 ] ] }
data.set(d3, "a[0][1]", 3)
self.assertEqual(d3, { 'a': [ [ 1, 3 ] ] })
data.unset(d3, 'a[0][1]')
self.assertEqual(d3, { 'a': [ [ 1 ] ] })
def test_find_no_exc(self):
d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2, 'more': { 'data': 'here' } } }
......
......@@ -395,19 +395,31 @@ class UIModuleCCSession(MultiConfigData):
a DataTypeError if the value at the identifier is not a list,
or if the given value_str does not match the list_item_spec
"""
if identifier == "":
identifier = value_str
value_str = None
module_spec = self.find_spec_part(identifier)
if (type(module_spec) != dict or "list_item_spec" not in module_spec):
raise isc.cc.data.DataNotFoundError(str(identifier) + " is not a list")
value = isc.cc.data.parse_value_str(value_str)
isc.config.config_data.check_type(module_spec, [value])
cur_list, status = self.get_value(identifier)
#if not cur_list:
# cur_list = isc.cc.data.find_no_exc(self.config.data, identifier)
if not cur_list:
cur_list = []
if value in cur_list:
cur_list.remove(value)
self.set_value(identifier, cur_list)
if value_str is None:
# we are directly removing an list index
id, list_indices = isc.cc.data.split_identifier_list_indices(identifier)
if list_indices is None:
raise DataTypeError("identifier in remove_value() does not contain a list index, and no value to remove")
else:
self.set_value(identifier, None)
else:
value = isc.cc.data.parse_value_str(value_str)
isc.config.config_data.check_type(module_spec, [value])
cur_list, status = self.get_value(identifier)
#if not cur_list:
# cur_list = isc.cc.data.find_no_exc(self.config.data, identifier)
if not cur_list:
cur_list = []
if value in cur_list:
cur_list.remove(value)
self.set_value(identifier, cur_list)
def commit(self):
"""Commit all local changes, send them through b10-cmdctl to
......
......@@ -108,13 +108,11 @@ def find_spec_part(element, identifier):
id_parts = identifier.split("/")
id_parts[:] = (value for value in id_parts if value != "")
cur_el = element
for id in id_parts:
#for id in id_parts:
for id_part in id_parts:
# strip list selector part
# don't need it for the spec part, so just drop it
i = id.find('[')
e = id.find(']')
if i >= 0 and e > i + 1:
id = id[:i]
id, list_indices = isc.cc.data.split_identifier_list_indices(id_part)
if type(cur_el) == dict and 'map_item_spec' in cur_el.keys():
found = False
for cur_el_item in cur_el['map_item_spec']:
......@@ -127,20 +125,12 @@ def find_spec_part(element, identifier):
found = False
for cur_el_item in cur_el:
if cur_el_item['item_name'] == id:
#print("[XX] full list item:")
#print(cur_el_item)
#if 'list_item_spec' in cur_el_item:
# cur_el = cur_el_item['list_item_spec']
#else:
cur_el = cur_el_item
found = True
if not found:
raise isc.cc.data.DataNotFoundError(id + " in " + str(cur_el))
else:
raise isc.cc.data.DataNotFoundError("Not a correct config specification")
print("[XX] Returning: ")
print(cur_el)
print("[XX] end")
return cur_el
def spec_name_list(spec, prefix="", recurse=False):
......@@ -242,6 +232,20 @@ class ConfigData:
result[item] = value
return result
# should we just make a class for these?
def _create_value_map_entry(name, type, value, status = None):
entry = {}
entry['name'] = name
entry['type'] = type
entry['value'] = value
entry['modified'] = False
entry['default'] = False
if status == MultiConfigData.LOCAL:
entry['modified'] = True
if status == MultiConfigData.DEFAULT:
entry['default'] = True
return entry
class MultiConfigData:
"""This class stores the module specs, current non-default
configuration values and 'local' (uncommitted) changes for
......@@ -286,7 +290,7 @@ class MultiConfigData:
identifier (up to the first /) is interpreted as the module
name. Returns None if not found, or if identifier is not a
string."""
if type(identifier) != str:
if type(identifier) != str or identifier == "":
return None
if identifier[0] == '/':
identifier = identifier[1:]
......@@ -350,16 +354,16 @@ class MultiConfigData:
try:
spec = find_spec_part(self._specifications[module].get_config_spec(), id)
if 'item_default' in spec:
i = id.find('[')
e = id.find(']')
if i >= 0 and e > i + 1 \
and type(spec['item_default']) == list:
default_list = spec['item_default']
index = int(id[i + 1:e])
if index < len(default_list):
return default_list[index]
else:
return None
id, list_indices = isc.cc.data.split_identifier_list_indices(id)
if list_indices is not None and \
type(spec['item_default']) == list:
if len(list_indices) == 1:
default_list = spec['item_default']
index = list_indices[0]
if index < len(default_list):
return default_list[index]
else:
return None
else:
return spec['item_default']
else:
......@@ -377,20 +381,15 @@ class MultiConfigData:
set DEFAULT if the argument 'default' is False (default
defaults to True)"""
value = self.get_local_value(identifier)
print("[XX] mcd get_value() for: " + identifier)
print("[XX] mcd get_value() local: " + str(value))
if value != None:
return value, self.LOCAL
value = self.get_current_value(identifier)
print("[XX] mcd get_value() current: " + str(value))
if value != None:
return value, self.CURRENT
if default:
value = self.get_default_value(identifier)
print("[XX] mcd get_value() default: " + str(value))
if value != None:
return value, self.DEFAULT
print("[XX] mcd get_value() nothing found")
return None, self.NONE
def get_value_maps(self, identifier = None):
......@@ -407,12 +406,7 @@ class MultiConfigData:
if not identifier:
# No identifier, so we need the list of current modules
for module in self._specifications.keys():
entry = {}
entry['name'] = module
entry['type'] = 'module'
entry['value'] = None
entry['modified'] = False
entry['default'] = False
entry = _create_value_map_entry(module, 'module', None)
result.append(entry)
else:
if identifier[0] == '/':
......@@ -423,71 +417,37 @@ class MultiConfigData:
spec_part = find_spec_part(spec.get_config_spec(), id)
if type(spec_part) == list:
for item in spec_part:
entry = {}
entry['name'] = item['item_name']
entry['type'] = item['item_type']
print("[XX] GET VALUE FOR: " + str("/" + identifier + "/" + item['item_name']))
value, status = self.get_value("/" + identifier + "/" + item['item_name'])
entry['value'] = value
if status == self.LOCAL:
entry['modified'] = True
else:
entry['modified'] = False
if status == self.DEFAULT:
entry['default'] = False
else:
entry['default'] = False
value, status = self.get_value("/" + identifier\
+ "/" + item['item_name'])
entry = _create_value_map_entry(item['item_name'],
item['item_type'],
value, status)
result.append(entry)
elif type(spec_part) == dict:
item = spec_part
if item['item_type'] == 'list':
li_spec = item['list_item_spec']
print("[XX] GET VALUE FOR: " + str("/" + identifier))
value, status = self.get_value("/" + identifier)
print("[XX] ITEM_LIST: " + str(value))
if type(value) == list:
for list_value in value:
result_part2 = {}
result_part2['name'] = li_spec['item_name']
result_part2['value'] = list_value
result_part2['type'] = li_spec['item_type']
result_part2['default'] = False
result_part2['modified'] = False
result_part2 = _create_value_map_entry(
li_spec['item_name'],
li_spec['item_type'],
list_value)
result.append(result_part2)
elif value is not None:
entry = {}
entry['name'] = li_spec['item_name']
entry['type'] = li_spec['item_type']
entry['value'] = value
if status == self.LOCAL:
entry['modified'] = True
else:
entry['modified'] = False
if status == self.DEFAULT:
entry['default'] = False
else:
entry['default'] = False
entry = _create_value_map_entry(
li_spec['item_name'],
li_spec['item_type'],
value, status)
result.append(entry)
else:
#value, status = self.get_value("/" + identifier + "/" + item['item_name'])
print("[XX] GET VALUE FOR: " + str("/" + identifier))
# The type of the config data is a list,
# so we do not want to have a default if it's
# out of range
value, status = self.get_value("/" + identifier, False)
value, status = self.get_value("/" + identifier)
if value is not None:
entry = {}
entry['name'] = item['item_name']
entry['type'] = item['item_type']
entry['value'] = value
if status == self.LOCAL:
entry['modified'] = True
else:
entry['modified'] = False
if status == self.DEFAULT:
entry['default'] = False
else:
entry['default'] = False
entry = _create_value_map_entry(
item['item_name'],
item['item_type'],
value, status)
result.append(entry)
return result
......@@ -496,15 +456,28 @@ class MultiConfigData:
there is a specification for the given identifier, the type
is checked."""
spec_part = self.find_spec_part(identifier)
print("[XX] SPEC PART FOR " + identifier + ": ")
print(spec_part)
if spec_part != None:
i = identifier.find('[')
e = identifier.find(']')
if i >= 0 and e > i and spec_part['item_type'] == 'list':
if spec_part is not None and value is not None:
id, list_indices = isc.cc.data.split_identifier_list_indices(identifier)
if list_indices is not None \
and spec_part['item_type'] == 'list':
spec_part = spec_part['list_item_spec']
check_type(spec_part, value)
# TODO: get the local list to value
# Since we do not support list diffs (yet?), we need to
# copy the currently set list of items to _local_changes
# if we want to modify an element in there
# (for any list indices specified in the full identifier)
id_parts = isc.cc.data.split_identifier(identifier)
cur_id_part = '/'
for id_part in id_parts:
id, list_indices = isc.cc.data.split_identifier_list_indices(id_part)
if list_indices is not None:
cur_list, status = self.get_value(cur_id_part + id)
if status != MultiConfigData.LOCAL:
isc.cc.data.set(self._local_changes,
cur_id_part + id,
cur_list)
cur_id_part = cur_id_part + id_part + "/"
isc.cc.data.set(self._local_changes, identifier, value)
def get_config_item_list(self, identifier = None, recurse = False):
......
......@@ -637,6 +637,8 @@ class TestUIModuleCCSession(unittest.TestCase):
self.assertEqual({'Spec2': {'item5': ['foo']}}, uccs._local_changes)
uccs.add_value("Spec2/item5", "foo")
self.assertEqual({'Spec2': {'item5': ['foo']}}, uccs._local_changes)
uccs.remove_value("Spec2/item5[0]", None)
self.assertEqual({'Spec2': {'item5': []}}, uccs._local_changes)
def test_commit(self):
fake_conn = fakeUIConn()
......
......@@ -353,20 +353,34 @@ class TestMultiConfigData(unittest.TestCase):
module_spec = isc.config.module_spec_from_file(self.data_path + os.sep + "spec2.spec")
self.mcd.set_specification(module_spec)
self.mcd.set_value("Spec2/item1", 2)
value,status = self.mcd.get_value("Spec2/item1")
value, status = self.mcd.get_value("Spec2/item1")
self.assertEqual(2, value)
self.assertEqual(MultiConfigData.LOCAL, status)
value,status = self.mcd.get_value("Spec2/item2")
value, status = self.mcd.get_value("Spec2/item2")
self.assertEqual(1.1, value)
self.assertEqual(MultiConfigData.DEFAULT, status)
self.mcd._current_config = { "Spec2": { "item3": False } }
value,status = self.mcd.get_value("Spec2/item3")
value, status = self.mcd.get_value("Spec2/item3")
self.assertEqual(False, value)
self.assertEqual(MultiConfigData.CURRENT, status)
value,status = self.mcd.get_value("Spec2/no_such_item")
value, status = self.mcd.get_value("Spec2/no_such_item")
self.assertEqual(None, value)
self.assertEqual(MultiConfigData.NONE, status)
value, status = self.mcd.get_value("Spec2/item5[0]")
self.assertEqual("a", value)
self.assertEqual(MultiConfigData.DEFAULT, status)
value, status = self.mcd.get_value("Spec2/item5[0]", False)
self.assertEqual(None, value)
self.assertEqual(MultiConfigData.NONE, status)
def test_get_value_maps(self):