#!/usr/bin/python2

# Copyright 2017 Intel Corporation, All rights reserved
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import print_function
import sys
import argparse
import subprocess
from xml.etree import ElementTree as ET
import xml.dom.minidom


def error(*argv):
    """Report a fatal error and terminate the program.

    The positional arguments are printed back-to-back (no separator) after
    an "Error: " prefix, followed by a newline.  The message is written to
    stderr so it cannot end up mixed into (or redirected together with) the
    XML that this tool prints on stdout when no output file is given.

    Raises:
        SystemExit: always, with exit status 1.
    """
    print("Error: ", *argv, sep="", file=sys.stderr)
    sys.exit(1)


def mapvcpu(args, p):
    """Return the host CPU number used as the ``cpuset`` for guest vCPU ``p``.

    vCPUs are taken in groups of ``args.threads``.  The group index is added
    directly to the ``args.omit_cores`` offset, while the position inside the
    group is scaled by ``args.core_num`` (the stride between the threads of
    one group).
    """
    group, slot = divmod(p, args.threads)
    return args.omit_cores + args.core_num * slot + group


def writeout(args, domain):
    """Serialize ``domain`` as indented XML and emit it.

    The text is written to ``args.output_file`` when that is set (nothing is
    appended after the last line); otherwise it is printed to stdout.
    """
    raw_xml = ET.tostring(domain, encoding='utf-8', method='xml')
    pretty = xml.dom.minidom.parseString(raw_xml).toprettyxml(indent="  ")
    # Keep only the lines that carry content; blank / whitespace-only lines
    # of the pretty-printed text are dropped.
    pxml = '\n'.join(row for row in pretty.split('\n') if row.strip())

    if not args.output_file:
        print(pxml)
        return

    with open(args.output_file, "wt") as out:
        out.write(pxml)


def writewholeconfig(args):
    """Build a complete example KVM domain definition for virsh and emit it.

    The generated ``<domain type="kvm">`` contains, in this order:

    * ``name`` ("vmm"), ``memory`` (``args.total_mem_size`` GiB) and a
      statically placed ``vcpu`` count (``args.vcpu``);
    * a host-passthrough ``cpu`` with one NUMA ``cell`` per entry of
      ``args.numa`` and a single-socket ``topology`` built from
      ``args.vm_cores`` and ``args.threads``;
    * ``numatune`` with one strict ``memnode`` per entry of ``args.numa``;
    * ``os`` (hvm, x86_64, machine "pc", boot from hd) and the acpi/apic
      ``features``;
    * ``devices``: a qcow2 IDE disk backed by ``args.vm_file``, an interface
      on the "default" network and VNC graphics;
    * ``cputune`` with one ``vcpupin`` per vCPU, using ``mapvcpu`` for the
      cpuset.

    The finished tree is passed to ``writeout``.
    """
    domain = ET.Element("domain", {"type": "kvm"})
    ET.SubElement(domain, "name").text = "vmm"
    ET.SubElement(
        domain, "memory", {"unit": "GiB"}).text = str(args.total_mem_size)
    ET.SubElement(
        domain, "vcpu", {"placement": "static"}).text = str(args.vcpu)

    cpu = ET.SubElement(domain, "cpu", {"mode": "host-passthrough"})
    numa = ET.SubElement(cpu, "numa")
    for cell in args.numa:
        ET.SubElement(numa, "cell", {"id": str(cell["id"]),
                                     "cpus": cell["cpus"],
                                     "memory": str(cell["mem"]),
                                     "unit": "GiB"})
    ET.SubElement(cpu, "topology", {"sockets": "1",
                                    "cores": str(args.vm_cores),
                                    "threads": str(args.threads)})

    # One memnode per guest cell that was actually defined above, so that
    # numatune never refers to a cell id missing from <numa>.  The nodeset
    # uses the same number as the cell id.
    numatune = ET.SubElement(domain, "numatune")
    for cell in args.numa:
        cell_id = str(cell["id"])
        ET.SubElement(numatune, "memnode", {"cellid": cell_id,
                                            "mode": "strict",
                                            "nodeset": cell_id})

    os_elem = ET.SubElement(domain, "os")
    ET.SubElement(
        os_elem, "type", {"arch": "x86_64", "machine": "pc"}).text = "hvm"
    ET.SubElement(os_elem, "boot", {"dev": "hd"})

    features = ET.SubElement(domain, "features")
    for feature in ("acpi", "apic"):
        ET.SubElement(features, feature)

    devices = ET.SubElement(domain, "devices")
    disk = ET.SubElement(devices, "disk", {"type": "file", "device": "disk"})
    ET.SubElement(disk, "driver", {"name": "qemu", "type": "qcow2"})
    ET.SubElement(disk, "source", {"file": args.vm_file})
    ET.SubElement(disk, "target", {"bus": "ide", "dev": "hda"})
    ET.SubElement(disk, "address", {"type": "drive"})
    interface = ET.SubElement(devices, "interface", {"type": "network"})
    ET.SubElement(interface, "source", {"network": "default"})
    ET.SubElement(devices, "graphics", {"type": "vnc"})

    cputune = ET.SubElement(domain, "cputune")
    for p in range(args.vcpu):
        ET.SubElement(cputune, "vcpupin", {"vcpu": str(p),
                                           "cpuset": str(mapvcpu(args, p))})

    writeout(args, domain)


def updateconfig(args):
    """Replace the vCPU pinning of an existing domain definition.

    Parses ``args.input_file``, removes every ``<cputune>`` element that is a
    direct child of the root ``<domain>``, then appends a fresh ``<cputune>``
    with one ``<vcpupin>`` per vCPU (cpuset computed by ``mapvcpu``).  The new
    element is bracketed by BEGIN/END XML comments carrying
    ``args.toolname``.  The resulting tree is passed to ``writeout``.

    Raises:
        AssertionError: if the root element of the input is not ``domain``.
    """
    with open(args.input_file, "rb") as f:
        tree = ET.parse(f)

    domain = tree.getroot()
    assert domain is not None
    assert domain.tag == 'domain'

    # Remove the existing pinning.  This deliberately avoids a truthiness
    # test on the element: an Element without children evaluates as False,
    # so an empty <cputune/> would otherwise survive and be duplicated.
    # Only direct children are considered because Element.remove() raises
    # ValueError for anything that is not a direct child.
    for stale in list(domain.findall('cputune')):
        domain.remove(stale)

    domain.append(ET.Comment("BEGIN [%s]" % args.toolname))

    cputune = ET.SubElement(domain, "cputune")
    for p in range(args.vcpu):
        ET.SubElement(cputune, "vcpupin", {"vcpu": str(
            p), "cpuset": str(mapvcpu(args, p))})

    domain.append(ET.Comment("END [%s]" % args.toolname))

    writeout(args, domain)


def doit(args):
    """Run the tool: update the vCPU pinning of the input configuration when
       ``args.input_file`` is set, otherwise write an entire configuration."""

    handler = updateconfig if args.input_file else writewholeconfig
    handler(args)


def vcpusstr(first, count):
    """Format ``count`` consecutive ids starting at ``first`` as "first-last"."""

    last = first + count - 1
    return "%d-%d" % (first, last)


def gt0(string):
    """Argument validator: accept only integers greater than 0.

    Intended for use as an argparse ``type=`` callable.

    Returns:
        The value converted with ``int()``.

    Raises:
        argparse.ArgumentTypeError: if the value cannot be converted to an
            integer or is not greater than 0.
    """

    try:
        value = int(string)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "not a number".  Anything else
        # (KeyboardInterrupt, SystemExit, ...) must propagate instead of
        # being silently turned into a validation error.
        value = None

    if value is not None and value > 0:
        return value

    msg = "%r is not a positive number" % string
    raise argparse.ArgumentTypeError(msg)


def parseargs(cmd_args):
    """Parse and validate the command line and derive the guest layout.

    Args:
        cmd_args: list of argument strings, without the program name.

    Returns:
        The argparse namespace, extended with:

        * ``toolname``   - last path component of ``sys.argv[0]``;
        * ``threads``    - threads per core (fixed to 4);
        * ``vcpu``       - ``vm_cores * threads``;
        * ``node0_vcpu`` - vCPUs left for NUMA cell 0;
        * ``numa``       - list of cell dicts with keys ``id``, ``mem``
          (GiB) and ``cpus`` (an "a-b" range string).

    The program is terminated through ``error`` (exit status 1) when no VM
    image is given while generating a whole configuration, when more cores
    are requested than ``--core-num`` provides, or - again only when a whole
    configuration is generated - when the NUMA cell 1 settings leave no vCPU
    or a negative amount of memory for cell 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-file",
                        help="Input file path")
    parser.add_argument("-o", "--output-file",
                        help="Output file path")

    # gt0 already guarantees every numeric option below is an int > 0.
    parser.add_argument("--core-num",
                        help="Number of cores.",
                        required=True,
                        type=gt0)
    parser.add_argument("--total-mem-size",
                        help="Total memory size assigned to virt guest (DRAM+MCDRAM)",
                        required=True,
                        type=gt0)

    parser.add_argument("--omit-cores",
                        help="Number of first cores designated to host system. (default=2)",
                        default=2, type=gt0)

    parser.add_argument("--vm-cores",
                        help="Number of cores designated to vm. (default=61)",
                        default=61, type=gt0)

    parser.add_argument("-f", "--vm-file",
                        help="The path to the VM image file")

    parser.add_argument('--numanode1-ram',
                        help="Extra NUMA node ram size in GB. (MCDRAM, default = 16)",
                        type=gt0, default=16,
                        dest="node1_ram")

    parser.add_argument('--numanode1-vcpus',
                        help="Extra NUMA node vcpu number. (default = 4)",
                        type=gt0, default=4,
                        dest="node1_vcpu")

    args = parser.parse_args(cmd_args)

    args.toolname = sys.argv[0].split("/")[-1]

    # A whole configuration is generated only when no input file is given;
    # with an input file just the vCPU pinning is rewritten.
    whole_config = args.input_file is None

    if whole_config and args.vm_file is None:
        error("VM image file is not specified")

    required_cores = args.vm_cores + args.omit_cores
    if required_cores > args.core_num:
        error("Number of required cores (%d) is bigger than available in system (%d)" % (
            required_cores, args.core_num))

    args.threads = 4  # Xeon PHI const
    args.vcpu = args.vm_cores * args.threads
    args.node0_vcpu = args.vcpu - args.node1_vcpu
    node0_mem = args.total_mem_size - args.node1_ram

    # The NUMA cells are only consumed when a whole configuration is
    # written, so the split is only validated in that mode.  Without these
    # checks cell 0 would get a nonsensical cpus range such as "0--1" or a
    # negative memory size.
    if whole_config:
        if args.node0_vcpu < 1:
            error("Number of vcpus of extra NUMA node (%d) must be smaller than total number of vcpus (%d)" % (
                args.node1_vcpu, args.vcpu))
        if node0_mem < 0:
            error("Extra NUMA node ram size (%d) is bigger than total memory size (%d)" % (
                args.node1_ram, args.total_mem_size))

    args.numa = [{"id": 0, "mem": node0_mem,
                  "cpus": vcpusstr(0, args.node0_vcpu)}]
    # Cell 1 is skipped only when it has neither ram nor vcpus.  With the
    # gt0 validators above both values are always positive, so the second
    # cell is always present for command-line input.
    if not (args.node1_ram == 0 and args.node1_vcpu == 0):
        args.numa.append({"id": 1, "mem": args.node1_ram,
                          "cpus": vcpusstr(args.node0_vcpu, args.node1_vcpu)})

    return args

if __name__ == "__main__":
    # Parse the command line (program name excluded), run the tool and use
    # its return value as the process exit status.
    cli_args = parseargs(sys.argv[1:])
    sys.exit(doit(cli_args))
