#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Plot a graph from one or more GTK+ Theme Torturer log files.

Written by Xan López, based on code by Federico Mena-Quintero.
Some small modifications by Marius Gedminas.

TODO: what's the licence?
"""

import optparse
import os
import xml.dom.minidom
import cairo

# Layout constants.  They are passed straight to cairo drawing calls; the
# plotting code applies no scaling transform of its own.
IMAGE_SIZE = (750, 400)     # (width, height) of the output image
MARGIN = 25                 # basic margin on all sides
LEGEND_SQUARE_SIZE = 10     # side length of the coloured squares in the legend
LEGEND_SPACING = 5          # spacing between a legend label and its square
LEGEND_PADDING = 20         # spacing between consecutive legend items
LEGEND_V_SPACE = 10         # vertical gap between the title and its underline
LEFT_MARGIN = 50            # x position of the vertical axis
TICK_WIDTH = 5              # length of the axis scale marks
COLUMNS_TOP_MARGIN = 55     # y position where the plot area starts
COLUMN_WIDTH = 20           # maximum width of a single bar (shrunk when the
                            # bars of a group would be wider than its label)
COLUMN_PADDING = 30         # horizontal padding between groups of bars

# Font used for all text; the widget title uses the larger size.
FONT_FACE = "Bitstream Vera Sans"
FONT_SIZE = 10
TITLE_FONT_SIZE = 15

# RGB colours (components in the 0..1 range) assigned to the timing series in
# order; the list is reused cyclically when there are more series than colours.
palette = [
    (0.98, 0.91, 0.31), # Tango Butter 1
    (0.94, 0.16, 0.16), # Scarlet Red 1
    (0.45, 0.62, 0.81), # Sky Blue 1
    (0.45, 0.82, 0.09), # Chameleon 2
    (0.5, 0.5, 0.5)]    # neutral mid-grey


def flatten(list_of_lists):
    """Concatenate the items of several iterables into one new list.

    Only one level of nesting is removed; the order of the items is kept.

    Example:

        >>> flatten([[1, 2], [3, 4]])
        [1, 2, 3, 4]

    """
    return [item for inner in list_of_lists for item in inner]


class TorturePlot:
    """Plot maker.

    Draws one bar chart per widget: every property of the widget gets a
    group of bars, one bar per timing series.

    Usage example::

        dataset = {'Button': {'boot::Expose': [1.5, 0.9],
                              'boot::Map': [0.2, 0.3]},
                   'Label': {'boot::Expose': [0.3, 0.2],
                             'boot::Map': [0.1, 0.1]},
                   }
        datanames = ['gtk2.6', 'gtk2.8']
        plot = TorturePlot(dataset, datanames)

        class Options:
            normalize = True
            prefix = "/tmp/my"
        options = Options()
        plot.plotAllWidgets(options)

    This code will create /tmp/myButton.png and /tmp/myLabel.png with the
    following graphs::

           Button      gtk2.6 XX  gtk2.8 YY
           -----------------------------------

        1.500s |-    XX
               |     XX
        1.125s |-    XX
               |     XX YY
        0.750s |-    XX YY
               |     XX YY
        0.375s |-    XX YY
               |     XX YY          XX YY
        0.000s |------------------------------
                  boot::Expose    boot::Map


           Label       gtk2.6 XX  gtk2.8 YY
           -----------------------------------

        1.500s |-
               |
        1.125s |-
               |
        0.750s |-
               |
        0.375s |-
               |     XX YY          XX YY
        0.000s |------------------------------
                  boot::Expose    boot::Map

    """

    def __init__(self, dataset, datanames):
        """Create a plotter for a dataset.

        ``dataset`` is a nested dictionary structure:
        dataset[widget][property] is a list of timing values.

        ``datanames`` is a list of names describing each timing series.
        """
        # Constants
        self.font_face = FONT_FACE
        self.font_size = FONT_SIZE
        self.title_font_size = TITLE_FONT_SIZE
        self.dataset = dataset
        self.datanames = datanames
        self.background_color = (1.0, 1.0, 1.0)
        self.text_color = (0, 0, 0)
        self.column_width = COLUMN_WIDTH
        # Largest timing over the whole dataset; computed lazily the first
        # time a normalized plot is drawn.
        self.common_scale = None
        # Value at the top of the scale of the plot being drawn; updated by
        # _plotAxis for every widget.
        self.biggest = 0.0
        # Height available for the tallest bar.  Computed here rather than as
        # a side effect of drawing the header, so it is always defined.
        (width, height) = self._computeSize(self.dataset)
        self.column_max_height = height - COLUMNS_TOP_MARGIN - MARGIN
        self.colors = self._computeTimingColors()
        self.scaleMarks = self._computeScaleMarks()

    def _computeScaleMarks(self):
        """Return the y coordinates of the five scale marks, top to bottom."""
        (width, height) = self._computeSize(self.dataset)
        inside_height = height - MARGIN - COLUMNS_TOP_MARGIN
        # Floor division keeps the marks on whole coordinates regardless of
        # the interpreter's division semantics.
        return [COLUMNS_TOP_MARGIN + inside_height * quarter // 4
                for quarter in range(5)]

    def _computeTimingColors(self):
        """Return one colour per timing series, cycling through the palette."""
        return [palette[i % len(palette)]
                for i in range(len(self.datanames))]

    def _colorForSeries(self, index):
        """Return the colour of the ``index``-th timing series.

        Falls back to cycling through the palette when a property has more
        values than there are named series, instead of raising IndexError.
        """
        if index < len(self.colors):
            return self.colors[index]
        return palette[index % len(palette)]

    def _computeSize(self, dataset):
        """Return the (width, height) of the output image."""
        # Make this dynamic, and do it only once on startup!!
        return IMAGE_SIZE

    def _textWidth(self, text, ctx):
        """Return the width of ``text`` with the font currently set on ctx."""
        (x_bearing, y_bearing, width, height,
         x_advance, y_advance) = ctx.text_extents(text)
        return width

    def _plotHeader(self, widget, ctx):
        """Paint the background, the widget name, the legend and underline."""
        (width, height) = self._computeSize(self.dataset)
        # Background
        ctx.set_source_rgb(*self.background_color)
        ctx.rectangle(0, 0, width, height)
        ctx.fill()

        # Widget name
        ctx.set_source_rgb(*self.text_color)
        ctx.move_to(MARGIN, MARGIN)
        ctx.set_font_size(self.title_font_size)
        ctx.show_text(widget)

        # Legend: "<series name> <coloured square>" for every series.  The
        # title width is measured while the title font size is still set.
        widget_width = self._textWidth(widget, ctx)
        x = MARGIN + widget_width + MARGIN
        ctx.move_to(x, MARGIN)
        ctx.set_font_size(self.font_size)
        for index, series_name in enumerate(self.datanames):
            ctx.set_source_rgb(*self.text_color)
            ctx.show_text(series_name)
            x += LEGEND_SPACING + self._textWidth(series_name, ctx)
            ctx.rectangle(x, MARGIN - LEGEND_SQUARE_SIZE,
                          LEGEND_SQUARE_SIZE, LEGEND_SQUARE_SIZE)
            ctx.set_source_rgb(*self._colorForSeries(index))
            ctx.fill()
            x += LEGEND_PADDING
            # fill() cleared the path, so re-establish the text position.
            ctx.move_to(x, MARGIN)

        # Underline below the title and legend
        ctx.set_source_rgb(*self.text_color)
        ctx.move_to(MARGIN, MARGIN + LEGEND_V_SPACE)
        ctx.line_to(width - MARGIN, MARGIN + LEGEND_V_SPACE)
        ctx.stroke()

    def _largestTiming(self, widget_data):
        """Return the largest timing of one widget, or 0 if it has none."""
        timings = flatten(widget_data.values())
        if not timings:
            return 0
        return max(timings)

    def _calculateCommonScale(self):
        """Return the largest timing found anywhere in the dataset."""
        biggest = 0
        for widget_data in self.dataset.values():
            biggest = max(biggest, self._largestTiming(widget_data))
        return biggest

    def _plotAxis(self, widget, ctx, options):
        """Draw both axes and the labelled scale; set ``self.biggest``."""
        (width, height) = self._computeSize(self.dataset)
        ctx.set_line_width(1.0)
        ctx.move_to(LEFT_MARGIN, COLUMNS_TOP_MARGIN - TICK_WIDTH * 2)
        ctx.line_to(LEFT_MARGIN, height - MARGIN)
        ctx.line_to(width - MARGIN, height - MARGIN)
        ctx.stroke()

        # Scale
        if options.normalize:
            if self.common_scale is None:
                self.common_scale = self._calculateCommonScale()
            biggest = self.common_scale
        else:
            biggest = self._largestTiming(self.dataset[widget])
        # Force a float so integer timings are not truncated by division.
        self.biggest = float(biggest)

        # Draw the scale
        step = self.biggest / (len(self.scaleMarks) - 1)
        mark = self.biggest
        ctx.set_font_size(self.font_size)

        for y in self.scaleMarks:
            # No tick on the bottom mark: the horizontal axis is there.
            if y != height - MARGIN:
                ctx.move_to(LEFT_MARGIN, y)
                ctx.line_to(LEFT_MARGIN + TICK_WIDTH, y)
                ctx.stroke()
            ctx.move_to(5, y + 2)
            ctx.show_text("%.3fs" % mark)
            mark -= step
            # Guard against a slightly negative label from rounding errors.
            if mark < 0:
                mark = 0

    def plotWidget(self, widget, options):
        """Generate a plot for a single widget.

        ``widget`` is the name of a widget in the dataset.

        ``options`` is an options object with these attributes:

            - ``prefix`` -- filename prefix (``None`` is treated as no
                            prefix).

            - ``normalize`` -- True if you want the same scale for all plots;
                               False if you want to scale each widget's plot
                               independently.

        The plot will be written to a file named ``widget.png``, with the
        appropriate prefix added.
        """
        (width, height) = self._computeSize(self.dataset)

        surface = cairo.ImageSurface(cairo.FORMAT_RGB24, width, height)
        ctx = cairo.Context(surface)
        ctx.select_font_face(self.font_face, cairo.FONT_SLANT_NORMAL,
                             cairo.FONT_WEIGHT_BOLD)
        ctx.set_font_size(self.font_size)

        # Header
        self._plotHeader(widget, ctx)

        # Axis (also determines self.biggest, which the columns need)
        self._plotAxis(widget, ctx, options)

        # Columns: one group of bars per property, in alphabetical order.
        x = LEFT_MARGIN + COLUMN_PADDING // 2
        for name, values in sorted(self.dataset[widget].items()):
            n_events = len(values)
            label_width = self._textWidth(name, ctx)
            ctx.set_source_rgb(*self.text_color)
            ctx.set_font_size(self.font_size)
            ctx.move_to(x, height - MARGIN // 2)
            ctx.show_text(name)

            # The group may not be wider than its label: shrink the bars
            # when needed, otherwise centre them under the label.
            if self.column_width * n_events > label_width:
                column_width = label_width / n_events
            else:
                column_width = self.column_width
            margin = (label_width - (column_width * n_events)) / 2.0
            # Leave a gap between neighbouring bars, but never let a bar
            # collapse to a zero or negative width.
            bar_width = max(column_width - 5, 1)

            offset = 0
            for index, timing in enumerate(values):
                if self.biggest > 0:
                    column_height = (timing * self.column_max_height /
                                     self.biggest)
                else:
                    # All timings are zero: nothing to scale against.
                    column_height = 0
                ctx.rectangle(x + offset + margin,
                              (height - MARGIN - 2) - column_height,
                              bar_width,
                              column_height)
                ctx.set_source_rgb(*self._colorForSeries(index))
                ctx.fill()
                offset += column_width

            x += label_width + COLUMN_PADDING

        # Profit!
        surface.write_to_png((options.prefix or "") + widget + ".png")

    def plotAllWidgets(self, options):
        """Generate plots for all widgets in the dataset.

        See ``plotWidget`` for an explanation of the ``options`` argument.
        """
        for widget in self.dataset:
            self.plotWidget(widget, options)


class TorturerParser:
    """Parse a list of XML log files produced by the GTK+ Theme Torturer.

    Usage example:

        parser = TorturerParser(["gtk2.6.xml", "gtk2.8.xml"])
        parser.parseWidgets()
        plot = TorturePlot(parser.result, parser.fileNames)

    """

    def __init__(self, fileList):
        """Create a parser.

        ``fileList`` is a list of filenames for the log files.

        ``fileNames`` holds the same names with the extension removed, and
        ``result`` starts out empty until ``parseWidgets`` is called.
        """
        self.fileList = fileList
        self.fileNames = [os.path.splitext(fileName)[0]
                          for fileName in self.fileList]
        # Defined up front so the attribute exists before parseWidgets().
        self.result = {}

    def parseWidgets(self):
        """Parse the log files.

        The result is placed in an attribute ``result``.  It is a nested
        dictionary structure: result[widget][property] is a list of timing
        values for a given property of a given widget.  Any previous result
        is discarded.
        """
        self.result = {}
        for fileName in self.fileList:
            self._parseWidget(fileName)

    def _timingValue(self, timing):
        """Return the number held in the text of a ``timing`` element.

        All text and CDATA children are joined, so comments inside the
        element or text split over several nodes do not break parsing.
        Raises ValueError if the text is not a number.
        """
        text = "".join([node.data for node in timing.childNodes
                        if node.nodeType in (node.TEXT_NODE,
                                             node.CDATA_SECTION_NODE)])
        return float(text)

    def _parseWidget(self, file):
        """Add the timings found in one log file to ``self.result``.

        Every ``timing`` element below a ``widget`` element is stored under
        the key "<name>::<subname>" built from its attributes.
        """
        # Parenthesised single argument: prints the same text whether
        # ``print`` is a statement or a function.
        print("Parsing file %s" % file)
        dom = xml.dom.minidom.parse(file)
        try:
            for widget in dom.getElementsByTagName("widget"):
                timings = self.result.setdefault(widget.getAttribute("name"),
                                                 {})
                for timing in widget.getElementsByTagName("timing"):
                    propertyName = (timing.getAttribute("name") + "::" +
                                    timing.getAttribute("subname"))
                    timings.setdefault(propertyName, []).append(
                        self._timingValue(timing))
        finally:
            # Break the DOM's internal reference cycles once we are done.
            dom.unlink()


def main():
    """Command-line entry point: parse the given logs and write the plots.

    The positional arguments are Theme Torturer XML log files.  One PNG per
    widget is written, named ``<prefix><widget>.png``.
    """
    option_parser = optparse.OptionParser(
        usage="usage: %prog -p prefix [TortureXMLFile1, TortureXMLFile2, ... ]")
    # default="" so that running without -p writes un-prefixed files instead
    # of failing when the prefix is concatenated with the widget name.
    option_parser.add_option("-p",
                             "--prefix", dest="prefix",
                             metavar="FILE", default="",
                             help="Prefix added to all the files generated"
                                  " by the program.")
    option_parser.add_option("-n",
                             "--normalize", dest="normalize",
                             action="store_true", default=False,
                             help="Draw all the images using the same scale")

    options, args = option_parser.parse_args()

    parser = TorturerParser(args)
    parser.parseWidgets()

    plot = TorturePlot(parser.result, parser.fileNames)
    plot.plotAllWidgets(options)


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
