Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add namespace support #9

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions tests/test_xmljson.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
import unittest

from collections import OrderedDict as Dict
from lxml.etree import tostring as tostring, fromstring
from lxml.etree import tostring as tostring, fromstring, ElementTree
from lxml.doctestcompare import LXMLOutputChecker
import lxml.html
import lxml.etree
import xml.etree.cElementTree
import xmljson


# For Python 3, decode byte strings as UTF-8
if sys.version_info[0] == 3:
def decode(s):
Expand Down Expand Up @@ -54,7 +55,24 @@ def compare(jsonstring, xmlstring):
first = json.loads(jsonstring, object_pairs_hook=Dict)
second = conv.data(fromstring(xmlstring))
self.assertEqual(first, second)
return compare

def check_nsmap(self, conv):
def compare(xmlstring):
result = conv.data(fromstring(xmlstring))
root = conv.etree(result)
t1 = fromstring(xmlstring)
t2 = root[0]
try:
t1.nsmap
except:
ns = {'charlie': "http://some-other-namespace"}

r1 = t1.find('charlie:joe', ns)
r2 = t2.find('charlie:joe', ns)
self.assertEqual(r1.tag, r2.tag)
return
self.assertEqual(t1.nsmap, t2.nsmap)
return compare


Expand Down Expand Up @@ -163,9 +181,9 @@ def test_data(self):
'<alice charlie="david">bob</alice>')

def test_xml_namespace(self):
'XML namespaces are not yet implemented'
with self.assertRaises(ValueError):
xmljson.badgerfish.etree({'alice': {'@xmlns': {'$': 'http:\/\/some-namespace'}}})
'Checks nsmap attribute of root tag'
eq = self.check_nsmap(xmljson.badgerfish)
eq('<alice xmlns="http://some-namespace" xmlns:charlie="http://some-other-namespace"><charlie:joe>bob</charlie:joe></alice>')

def test_custom_dict(self):
'Conversion to dict uses OrderedDict'
Expand Down
172 changes: 161 additions & 11 deletions xmljson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import sys
from collections import Counter, OrderedDict
from io import BytesIO

try:
from lxml.etree import Element
from lxml.etree import Element, iterparse, ElementTree, tostring
except ImportError:
from xml.etree.cElementTree import Element
from xml.etree.cElementTree import Element, iterparse, ElementTree

__author__ = 'S Anand'
__email__ = '[email protected]'
Expand All @@ -19,7 +21,7 @@

class XMLData(object):
def __init__(self, xml_fromstring=True, xml_tostring=True, element=None, dict_type=None,
list_type=None, attr_prefix=None, text_content=None, simple_text=False):
list_type=None, attr_prefix=None, text_content=None, ns_name=None, simple_text=False):
# xml_fromstring == False(y) => '1' -> '1'
# xml_fromstring == True => '1' -> 1
# xml_fromstring == fn => '1' -> fn(1)
Expand All @@ -44,13 +46,24 @@ def __init__(self, xml_fromstring=True, xml_tostring=True, element=None, dict_ty
# simple_text == True => '<x>a</x>' = {'x': 'a'}
self.simple_text = simple_text

self.ns_name = ns_name
try:
elem = Element("html", nsmap={None: 'test'})
elem.nsmap
self.lxml_lib = True
self.root_count = 0
except:
self.lxml_lib = False

@staticmethod
def _tostring(value):
'Convert value to XML compatible string'
if value is True:
value = 'true'
elif value is False:
value = 'false'
else:
value = str(value)
return unicode(value) # noqa: convert to whatever native unicode repr

@staticmethod
Expand All @@ -64,7 +77,10 @@ def _fromstring(value):
elif std_value == 'false':
return False
try:
return int(std_value)
if std_value.startswith('0'):
return std_value
else:
return int(std_value)
except ValueError:
pass
try:
Expand All @@ -76,6 +92,7 @@ def _fromstring(value):
def etree(self, data, root=None):
'Convert data structure into a list of etree.Element'
result = self.list() if root is None else root

if isinstance(data, (self.dict, dict)):
for key, value in data.items():
value_is_list = isinstance(value, (self.list, list))
Expand All @@ -88,7 +105,21 @@ def etree(self, data, root=None):
key = key.lstrip(self.attr_prefix)
# @xmlns: {$: xxx, svg: yyy} becomes xmlns="xxx" xmlns:svg="yyy"
if value_is_dict:
raise ValueError('XML namespaces not yet supported')
if self.lxml_lib:
if key == self.ns_name.lstrip(self.attr_prefix):
# Actually nothing to do here
pass
else:
for k in value.keys():
if len(k) > 0:
if k == self.text_content:
k_default = 'ns0'
self.ns_counter += 1
result.set('xmlns:' + k_default, self._tostring(value[k]))
else:
result.set('xmlns:' + k, self._tostring(value[k]))
else:
result.set('xmlns' + k, self._tostring(value[k]))
else:
result.set(key, self._tostring(value))
continue
Expand All @@ -105,8 +136,34 @@ def etree(self, data, root=None):
# Add other keys as one or more children
values = value if value_is_list else [value]
for value in values:
elem = self.element(key)
result.append(elem)
if value_is_dict:
# Add namespaces to nodes if @xmlns present
if self.ns_name in value.keys() and self.lxml_lib:
NS_MAP = self.dict()
for k in value[self.ns_name]:
prefix = k
if prefix == self.text_content:
prefix = 'ns0'
uri = value[self.ns_name][k]

if ':' in key:
prefix, tag = key.split(':')
key = tag

NS_MAP[prefix] = uri
continue

if len(value[self.ns_name]) > 1:
uri = ''
elem = self.element('{0}{1}'.format('{' + uri + '}', key), nsmap=NS_MAP)
result.append(elem)
else:
elem = self.element(key)
result.append(elem)
else:
elem = self.element(key)
result.append(elem)

# Treat scalars as text content, not children (Parker)
if not isinstance(value, (self.dict, dict, self.list, list)):
if self.text_content:
Expand All @@ -122,31 +179,124 @@ def etree(self, data, root=None):
def data(self, root):
'Convert etree.Element into a dictionary'
value = self.dict()
root = XMLData._process_ns(self, element=root)

children = [node for node in root if isinstance(node.tag, basestring)]
for attr, attrval in root.attrib.items():
attr = attr if self.attr_prefix is None else self.attr_prefix + attr
value[attr] = self._fromstring(attrval)

# form lxml.Element with namespaces if present
if self.lxml_lib:
if root.tag.startswith('{'):
uri, root.tag = root.tag.split('}')
uri = uri.lstrip('{')
nsmap = root.nsmap
value[self.ns_name] = {}

# pushing namespaces to dic; Filtering namespaces by prefix except root node
for key in nsmap.keys():
if self.root_count == 0:
value[self.ns_name].update({key: nsmap[key]})
else:
if nsmap[key] == uri:
value[self.ns_name].update({key: nsmap[key]})
self.root_count += 1
else:
for attr, attrval in root.attrib.items():
attr = attr if self.attr_prefix is None else self.attr_prefix + attr
value[attr] = self._fromstring(attrval)
else:
for attr, attrval in root.attrib.items():
attr = attr if self.attr_prefix is None else self.attr_prefix + attr

if self.attr_prefix:
if self.ns_name in attr:
if not attr.endswith(':'):
prefix = attr.split(':')[1]
value[attr.replace(prefix, '')] = {prefix: self._fromstring(attrval)}
else:
prefix = attr.split(':')[1]
value['@xmlns'] = {prefix: self._fromstring(attrval)}
else:
value[attr] = self._fromstring(attrval)
else:
value[attr] = self._fromstring(attrval)

if root.text and self.text_content is not None:
text = root.text.strip()
if text:
if self.simple_text and len(children) == len(root.attrib) == 0:
value = self._fromstring(text)
else:
value[self.text_content] = self._fromstring(text)

count = Counter(child.tag for child in children)
for child in children:
child = XMLData._process_ns(self, child)
if count[child.tag] == 1:
value.update(self.data(child))
else:
result = value.setdefault(child.tag, self.list())
result += self.data(child).values()
return self.dict([(root.tag, value)])

@staticmethod
def _process_ns(cls, element):
if element.tag.startswith('{'):
if any([True if k.split(':')[0] == 'xmlns' else False for k in element.attrib.keys()]):
revers_attr = {v:k for k,v in element.attrib.items()}

end_prefix = element.tag.find('}')
uri = element.tag[:end_prefix+1]
key_prefix = revers_attr[uri.strip('{}')]
prefix = key_prefix.split(':')[1]

if len(prefix) > 1:
element.tag = element.tag.replace(uri, prefix + ':')
else:
element.tag = element.tag.replace(uri, '')

# trick to determine if given element is root element
try:
_ = element.getroot()
element.attrib.pop(key_prefix, None)
except:
pass
else:
ns_keys = [k if k.split(':')[0] == 'xmlns' else None for k in element.attrib.keys()]
for key in ns_keys:
if key:
element.attrib.pop(key, None)
return element

@classmethod
def parse_nsmap(cls, file):
# Parse given file-like xml object for namespaces
if isinstance(file, (str)):
file = BytesIO(file.encode('utf-8'))

events = "start", "start-ns", "end-ns"
root = None
ns_map = []

for event, elem in iterparse(file, events):
if event == "start-ns":
ns_map.append(elem)
elif event == "end-ns":
ns_map.pop()
elif event == "start":
if root is None:
root = elem
if ns_map:
for ns in ns_map:
ns_prefix = ns[0]
ns_uri = ns[1]
elem.set('xmlns:{}'.format(ns_prefix), ns_uri)
return ElementTree(root).getroot()


class BadgerFish(XMLData):
'Converts between XML and data using the BadgerFish convention'
def __init__(self, **kwargs):
super(BadgerFish, self).__init__(attr_prefix='@', text_content='$', **kwargs)
super(BadgerFish, self).__init__(attr_prefix='@', text_content='$', ns_name='@xmlns', **kwargs)


class GData(XMLData):
Expand Down