0001try:
0002    import doctest
0003    doctest.OutputChecker
0004except AttributeError:
0005    import util.doctest24 as doctest
0006try:
0007    import xml.etree.ElementTree as ET
0008except ImportError:
0009    import elementtree.ElementTree as ET
0010from xml.parsers.expat import ExpatError as XMLParseError
0011
0012RealOutputChecker = doctest.OutputChecker
0013
0014def debug(*msg):
0015    import sys
0016    print >> sys.stderr, ' '.join(map(str, msg))
0017
0018class HTMLOutputChecker(RealOutputChecker):
0019
0020    def check_output(self, want, got, optionflags):
0021        normal = RealOutputChecker.check_output(self, want, got, optionflags)
0022        if normal or not got:
0023            return normal
0024        try:
0025            want_xml = make_xml(want)
0026        except XMLParseError, e:
0027            pass
0028        else:
0029            try:
0030                got_xml = make_xml(got)
0031            except XMLParseError:
0032                pass
0033            else:
0034                if xml_compare(want_xml, got_xml):
0035                    return True
0036        return False
0037
0038    def output_difference(self, example, got, optionflags):
0039        actual = RealOutputChecker.output_difference(
0040            self, example, got, optionflags)
0041        want_xml = got_xml = None
0042        try:
0043            want_xml = make_xml(example.want)
0044            want_norm = make_string(want_xml)
0045        except XMLParseError, e:
0046            if example.want.startswith('<'):
0047                want_norm = '(bad XML: %s)' % e
0048                #  '<xml>%s</xml>' % example.want
0049            else:
0050                return actual
0051        try:
0052            got_xml = make_xml(got)
0053            got_norm = make_string(got_xml)
0054        except XMLParseError, e:
0055            if example.want.startswith('<'):
0056                got_norm = '(bad XML: %s)' % e
0057            else:
0058                return actual
0059        s = '%s\nXML Wanted: %s\nXML Got   : %s\n' % (
0060            actual, want_norm, got_norm)
0061        if got_xml and want_xml:
0062            result = []
0063            xml_compare(want_xml, got_xml, result.append)
0064            s += 'Difference report:\n%s\n' % '\n'.join(result)
0065        return s
0066
0067def xml_compare(x1, x2, reporter=None):
0068    if x1.tag != x2.tag:
0069        if reporter:
0070            reporter('Tags do not match: %s and %s' % (x1.tag, x2.tag))
0071        return False
0072    for name, value in x1.attrib.items():
0073        if x2.attrib.get(name) != value:
0074            if reporter:
0075                reporter('Attributes do not match: %s=%r, %s=%r'
0076                         % (name, value, name, x2.attrib.get(name)))
0077            return False
0078    for name in x2.attrib.keys():
0079        if not x1.attrib.has_key(name):
0080            if reporter:
0081                reporter('x2 has an attribute x1 is missing: %s'
0082                         % name)
0083            return False
0084    if not text_compare(x1.text, x2.text):
0085        if reporter:
0086            reporter('text: %r != %r' % (x1.text, x2.text))
0087        return False
0088    if not text_compare(x1.tail, x2.tail):
0089        if reporter:
0090            reporter('tail: %r != %r' % (x1.tail, x2.tail))
0091        return False
0092    cl1 = x1.getchildren()
0093    cl2 = x2.getchildren()
0094    if len(cl1) != len(cl2):
0095        if reporter:
0096            reporter('children length differs, %i != %i'
0097                     % (len(cl1), len(cl2)))
0098        return False
0099    i = 0
0100    for c1, c2 in zip(cl1, cl2):
0101        i += 1
0102        if not xml_compare(c1, c2, reporter=reporter):
0103            if reporter:
0104                reporter('children %i do not match: %s'
0105                         % (i, c1.tag))
0106            return False
0107    return True
0108
0109def text_compare(t1, t2):
0110    if not t1 and not t2:
0111        return True
0112    if t1 == '*' or t2 == '*':
0113        return True
0114    return (t1 or '').strip() == (t2 or '').strip()
0115
0116def make_xml(s):
0117    return ET.XML('<xml>%s</xml>' % s)
0118
0119def make_string(xml):
0120    if isinstance(xml, (str, unicode)):
0121        xml = make_xml(xml)
0122    s = ET.tostring(xml)
0123    if s == '<xml />':
0124        return ''
0125    assert s.startswith('<xml>') and s.endswith('</xml>'), repr(s)
0126    return s[5:-6]
0127
0128def install():
0129    doctest.OutputChecker = HTMLOutputChecker