#!/usr/libexec/platform-python
#
# Copyright (c) 2017 eBay Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import getopt
import os
import re
import sys
import time

try:
    from ovs.db import idl
    from ovs import jsonrpc
    from ovs.poller import Poller
    from ovs.stream import Stream
    from ovs import dirs
except Exception:
    print("ERROR: Please install the correct Open vSwitch python support")
    print("       libraries (2.14.90).")
    print("       Alternatively, check that your PYTHONPATH is pointing to")
    print("       the correct location.")
    sys.exit(1)


# Program name as invoked; interpolated into the usage text, the version
# banner and the error messages written by this script.
argv0 = sys.argv[0]


def usage():
    """Print the command-line help text to stdout and exit with status 0.

    The option list includes --ovs, which main() accepts to additionally
    connect to the OVS database and annotate in_port/output lines.
    """
    print("""\
%(argv0)s:
usage: %(argv0)s < FILE
where FILE is output from ovs-appctl ofproto/trace.

The following options are also available:
  -h, --help                  display this help message
  -V, --version               display version information
  --ovnsb=DATABASE            use DATABASE as southbound DB
  --ovnnb=DATABASE            use DATABASE as northbound DB
  --ovsdb=DATABASE            use DATABASE as OVS DB
  --ovs                       also show OVS interfaces for in/out ports
  -p, --private-key=FILE      file with private key
  -c, --certificate=FILE      file with certificate for private key
  -C, --ca-cert=FILE          file with peer CA certificate\
""" % {'argv0': argv0})
    sys.exit(0)

def print_p(*args, **kwargs):
    """Call print() with the '  * ' marker inserted as the first item."""
    print('  * ', *args, **kwargs)


def print_h(*args, **kwargs):
    """Call print() with the '   * ' marker inserted as the first item.

    The marker carries one more leading space than the one used by print_p.
    """
    print('   * ', *args, **kwargs)

def datapath_str(datapath):
    """Return '"<name>" (<uuid>)' for a datapath record.

    <name> is the 'name' entry of datapath.external_ids converted with str(),
    so a missing entry is rendered as "None".
    """
    name = str(datapath.external_ids.get('name'))
    return '"%s" (%s)' % (name, datapath.uuid)

def chassis_str(chassis):
    """Describe the first element of a chassis sequence.

    Returns an empty string when the sequence is empty; otherwise only the
    first element's name and hostname are used.
    """
    if len(chassis) > 0:
        first = chassis[0]
        return 'chassis-name "%s", chassis-str "%s"' % (first.name,
                                                        first.hostname)
    return ''

class OVSDB(object):
    """Wrapper around an ovs.db.idl.Idl connection to a single database.

    The constructor fetches the named schema from the remote, creates the IDL
    and waits for a first change so that the lookup helpers below operate on
    populated tables.
    """

    # Timeout, in milliseconds, handed to Stream.open_block() per remote.
    STREAM_TIMEOUT_MS = 1000

    @staticmethod
    def wait_for_db_change(idl):
        """Run 'idl' until it reports a change.

        Loops until idl.change_seqno differs from its value on entry or
        idl.run() returns a true value.

        Raises:
            Exception: ('Retry Timeout') if 10 seconds have elapsed after a
                wait without a change having been observed.
        """
        seq = idl.change_seqno
        stop = time.time() + 10
        while idl.change_seqno == seq and not idl.run():
            poller = Poller()
            idl.wait(poller)
            # Cap the sleep at the remaining budget; otherwise the deadline
            # below is only examined whenever block() happens to return.
            # The +1 makes up for int() truncation so that the deadline has
            # really passed when the timer fires.
            poller.timer_wait(int((stop - time.time()) * 1000) + 1)
            poller.block()
            if time.time() >= stop:
                raise Exception('Retry Timeout')

    def __init__(self, remote, schema_name):
        """Connect to 'remote' and synchronise with database 'schema_name'.

        'remote' may hold several comma-separated remotes; see _get_schema().
        """
        self.remote = remote
        self._txn = None
        schema = self._get_schema(schema_name)
        schema.register_all()
        self._idl_conn = idl.Idl(remote, schema)
        OVSDB.wait_for_db_change(self._idl_conn)  # Initial Sync with DB

    def _get_schema(self, schema_name):
        """Fetch the schema of 'schema_name' through a 'get_schema' request.

        Each comma-separated entry of self.remote is tried in order and the
        first one that connects is used; failures are reported on stderr.

        Returns:
            An idl.SchemaHelper built from the reply.

        Raises:
            Exception: if no remote could be connected to, or if the request
                failed or returned an error.
        """
        strm = None
        for r in self.remote.split(','):
            error, strm = Stream.open_block(Stream.open(r),
                                            OVSDB.STREAM_TIMEOUT_MS)
            if not error and strm:
                break

            sys.stderr.write('Unable to connect to {}, error: {}\n'.format(r,
                os.strerror(error)))
            strm = None
        if not strm:
            raise Exception('Unable to connect to %s' % self.remote)

        rpc = jsonrpc.Connection(strm)
        req = jsonrpc.Message.create_request('get_schema', [schema_name])
        error, resp = rpc.transact_block(req)
        rpc.close()

        if error or resp.error:
            raise Exception('Unable to retrieve schema.')
        return idl.SchemaHelper(None, resp.result)

    def get_table(self, table_name):
        """Return the IDL table object registered under 'table_name'."""
        return self._idl_conn.tables[table_name]

    def _find_rows(self, table_name, find_fn):
        """Return a lazy iterator over the rows for which find_fn is true."""
        return filter(find_fn, self.get_table(table_name).rows.values())

    def find_rows_by_name(self, table_name, value):
        """Return an iterator over the rows whose 'name' equals 'value'."""
        return self._find_rows(table_name, lambda row: row.name == value)

    def find_rows_by_partial_uuid(self, table_name, value):
        """Return an iterator over rows whose UUID string starts with 'value'."""
        return self._find_rows(table_name,
                               lambda row: str(row.uuid).startswith(value))

    def get_first_record(self, table_name):
        """Return the first row of the table, or None if the table is empty."""
        table_rows = self.get_table(table_name).rows.values()
        if len(table_rows) == 0:
            return None
        return next(iter(table_rows))

class CookieHandler(object):
    """Base class for handlers that map a value from the trace to DB rows.

    The base implementation matches nothing and prints nothing; subclasses
    override get_records(), print_record() and optionally print_hint().
    """

    def __init__(self, db, table):
        """Remember the database handle and the table name to search."""
        self._table = table
        self._db = db

    def print(self, msg):
        """Emit 'msg' through print_h."""
        print_h(msg)

    def get_records(self, cookie):
        """Return the records matching 'cookie'; none in the base class."""
        return list()

    def print_record(self, record):
        """Print 'record'; a no-op in the base class."""
        return None

    def print_hint(self, record, db):
        """Print extra hints for 'record'; a no-op in the base class."""
        return None

class CookieHandlerByUUUID(CookieHandler):
    """Handler that matches a cookie against the start of row UUIDs."""

    def __init__(self, db, table):
        super().__init__(db, table)

    def get_records(self, cookie):
        """Return the rows of the table whose UUID starts with 'cookie'.

        The cookie is left-padded with zeroes to at least 8 characters first,
        to restore leading zeroes that may be missing from it.
        """
        padded = cookie.zfill(8)
        return self._db.find_rows_by_partial_uuid(self._table, padded)

class ACLHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'ACL' table of the given database."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'ACL')

    def print_record(self, acl):
        """Print the ACL's direction, priority, match and action.

        Surrounding double quotes are stripped from the match, and ' (log)'
        is appended when the ACL's log column is true.
        """
        fields = (acl.direction, acl.priority, acl.match.strip('"'),
                  acl.action)
        suffix = ' (log)' if acl.log else ''
        print_h('ACL: %s, priority=%s, match=(%s), %s' % fields + suffix)

class DHCPOptionsHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'DHCP_Options' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'DHCP_Options')

    def print_record(self, dhcp_opt):
        """Print the cidr and options columns of a DHCP_Options row."""
        fields = (dhcp_opt.cidr, dhcp_opt.options)
        print_h('DHCP Options: cidr %s options (%s)' % fields)

class ForwardingGroupHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Forwarding_Group' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'Forwarding_Group')

    def print_record(self, fwd_group):
        """Print name, vip, vmac, liveness and child ports of the group."""
        fmt = 'Forwarding Group: name %s vip %s vmac %s liveness %s child ports (%s)'
        fields = (fwd_group.name, fwd_group.vip, fwd_group.vmac,
                  fwd_group.liveness, fwd_group.child_port)
        print_h(fmt % fields)

class LSPHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Logical_Switch_Port' table."""

    def __init__(self, ovnnb_db):
        super(LSPHintHandler, self).__init__(ovnnb_db, 'Logical_Switch_Port')

    def print_record(self, lsp):
        """Print name, type, addresses, dynamic addresses and port security.

        The address details are wrapped in parentheses; the closing one was
        missing from the output and is now emitted.
        """
        print_h('Logical Switch Port: %s type %s (addresses %s, dynamic addresses %s, security %s)' % (
                    lsp.name, lsp.type, lsp.addresses, lsp.dynamic_addresses,
                    lsp.port_security))

class LRPHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Logical_Router_Port' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'Logical_Router_Port')

    def print_record(self, lrp):
        """Print name, mac, networks and ipv6_ra_configs of the port."""
        fmt = 'Logical Router Port: %s mac %s networks %s ipv6_ra_configs %s'
        fields = (lrp.name, lrp.mac, lrp.networks, lrp.ipv6_ra_configs)
        print_h(fmt % fields)

class LRPolicyHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Logical_Router_Policy' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'Logical_Router_Policy')

    def print_record(self, policy):
        """Print priority, match, action and nexthop of the policy."""
        fmt = 'Logical Router Policy: priority %s match %s action %s nexthop %s'
        fields = (policy.priority, policy.match, policy.action,
                  policy.nexthop)
        print_h(fmt % fields)

class LoadBalancerHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Load_Balancer' table (hint lookups)."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'Load_Balancer')

    def print_record(self, lb):
        """Print name, protocol, vips and ip_port_mappings of the LB."""
        fmt = 'Load Balancer: %s protocol %s vips %s ip_port_mappings %s'
        fields = (lb.name, lb.protocol, lb.vips, lb.ip_port_mappings)
        print_h(fmt % fields)

class NATHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'NAT' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'NAT')

    def print_record(self, nat):
        """Print external IP/MAC, logical IP/port and type of the NAT row."""
        fmt = 'NAT: external IP %s external_mac %s logical_ip %s logical_port %s type %s'
        fields = (nat.external_ip, nat.external_mac, nat.logical_ip,
                  nat.logical_port, nat.type)
        print_h(fmt % fields)

class StaticRouteHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Logical_Router_Static_Route' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'Logical_Router_Static_Route')

    def print_record(self, route):
        """Print prefix, nexthop, output port and policy of the route."""
        fields = (route.ip_prefix, route.nexthop, route.output_port,
                  route.policy)
        print_h('Route: %s via %s (port %s), policy=%s' % fields)

class QoSHintHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'QoS' table."""

    def __init__(self, ovnnb_db):
        super().__init__(ovnnb_db, 'QoS')

    def print_record(self, qos):
        """Print priority, direction, match, action and bandwidth."""
        fmt = 'QoS: priority %s direction %s match %s action %s bandwidth %s'
        fields = (qos.priority, qos.direction, qos.match, qos.action,
                  qos.bandwidth)
        print_h(fmt % fields)

class LogicalFlowHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Logical_Flow' table of 'ovnsb_db'.

    Besides printing the flow itself, it resolves the flow's 'stage-hint'
    external id against a fixed list of hint handlers that all search
    'ovnnb_db'.
    """

    def __init__(self, ovnnb_db, ovnsb_db):
        super(LogicalFlowHandler, self).__init__(ovnsb_db, 'Logical_Flow')
        # Every handler is tried for each hint, in this order.
        self._hint_handlers = [
            ACLHintHandler(ovnnb_db),
            DHCPOptionsHintHandler(ovnnb_db),
            ForwardingGroupHintHandler(ovnnb_db),
            LSPHintHandler(ovnnb_db),
            LRPHintHandler(ovnnb_db),
            LRPolicyHandler(ovnnb_db),
            LoadBalancerHintHandler(ovnnb_db),
            NATHintHandler(ovnnb_db),
            StaticRouteHintHandler(ovnnb_db),
            QoSHintHandler(ovnnb_db),
        ]

    def print_record(self, lflow):
        """Print the flow's datapaths followed by the flow itself.

        The datapaths are those in lflow.logical_datapath plus, when
        lflow.logical_dp_group is non-empty, the datapaths of its first
        element.
        """
        print_p('Logical datapaths:')
        # Work on a private list: extending the object returned by the row
        # attribute in place could alter the record if that object is shared
        # (whether it is depends on the IDL and is not visible here).
        datapaths = list(lflow.logical_datapath)
        if lflow.logical_dp_group:
            datapaths += lflow.logical_dp_group[0].datapaths
        for datapath in datapaths:
            print_p('    %s [%s]' % (datapath_str(datapath), lflow.pipeline))
        print_p('Logical flow: table=%s (%s), priority=%s, '
                'match=(%s), actions=(%s)' %
                    (lflow.table_id, lflow.external_ids.get('stage-name'),
                     lflow.priority,
                     str(lflow.match).strip('"'),
                     str(lflow.actions).strip('"')))

    def print_hint(self, lflow, ovnnb_db):
        """Print the records referenced by the flow's 'stage-hint', if any.

        'ovnnb_db' is accepted for interface compatibility but not used; the
        hint handlers were bound to a database in __init__().
        """
        hint = lflow.external_ids.get('stage-hint')
        if not hint:
            return
        for handler in self._hint_handlers:
            for i, record in enumerate(handler.get_records(hint)):
                # More than one row sharing the UUID prefix is flagged.
                if i > 0:
                    print_h('[Duplicate uuid hint]')
                handler.print_record(record)

class PortBindingHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Port_Binding' table."""

    def __init__(self, ovnsb_db):
        super().__init__(ovnsb_db, 'Port_Binding')

    def print_record(self, pb):
        """Print the binding's datapath, then its port, key and chassis."""
        datapath_line = 'Logical datapath: %s' % (datapath_str(pb.datapath))
        print_p(datapath_line)

        fields = (pb.logical_port, pb.tunnel_key, chassis_str(pb.chassis))
        print_p('Port Binding: logical_port "%s", tunnel_key %ld, %s' % fields)

class MacBindingHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'MAC_Binding' table."""

    def __init__(self, ovnsb_db):
        super().__init__(ovnsb_db, 'MAC_Binding')

    def print_record(self, mb):
        """Print the binding's datapath, then its ip, logical port and mac."""
        datapath_line = 'Logical datapath: %s' % (datapath_str(mb.datapath))
        print_p(datapath_line)

        fields = (mb.ip, mb.logical_port, mb.mac)
        print_p('MAC Binding: ip "%s", logical_port "%s", mac "%s"' % fields)

class MulticastGroupHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Multicast_Group' table."""

    def __init__(self, ovnsb_db):
        super().__init__(ovnsb_db, 'Multicast_Group')

    def print_record(self, mc):
        """Print the group's datapath, then its name, key and port names."""
        # The port list is assembled before anything is printed.
        port_names = ', '.join(pb.logical_port for pb in mc.ports)

        print_p('Logical datapath: %s' % (datapath_str(mc.datapath)))
        fields = (mc.name, mc.tunnel_key, port_names)
        print_p('Multicast Group: name "%s", tunnel_key %ld ports: (%s)' %
                fields)

class ChassisHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Chassis' table."""

    def __init__(self, ovnsb_db):
        super().__init__(ovnsb_db, 'Chassis')

    def print_record(self, chassis):
        """Print the chassis description produced by chassis_str()."""
        # chassis_str() expects a sequence, so wrap the single row.
        description = chassis_str([chassis])
        print_p('Chassis: %s' % (description))

class SBLoadBalancerHandler(CookieHandlerByUUUID):
    """UUID-prefix handler for the 'Load_Balancer' table of 'ovnsb_db'."""

    def __init__(self, ovnsb_db):
        super().__init__(ovnsb_db, 'Load_Balancer')

    def print_record(self, lb):
        """Print name, protocol and vips of the load balancer."""
        fields = (lb.name, lb.protocol, lb.vips)
        print_p('Load Balancer: %s protocol %s vips %s' % fields)

class OvsInterfaceHandler(CookieHandler):
    """Handler that maps an OpenFlow port number to an OVS interface."""

    def __init__(self, ovs_db):
        super().__init__(ovs_db, 'Interface')

        # Index the interfaces attached to the integration bridge by their
        # ofport; interfaces with an empty ofport column are left out.
        bridge = self.get_br_int()
        by_ofport = {}
        for port in bridge.ports:
            for intf in port.interfaces:
                if len(intf.ofport) > 0:
                    by_ofport[intf.ofport[0]] = intf
        self._intfs = by_ofport

    def get_br_int(self):
        """Return the first Bridge row named like the integration bridge.

        The name is the 'ovn-bridge' external id of the first Open_vSwitch
        record and falls back to 'br-int' when that id, or the record
        itself, is absent.  next() raises StopIteration when no Bridge row
        carries that name.
        """
        ovsrec = self._db.get_first_record('Open_vSwitch')
        br_name = (ovsrec.external_ids.get('ovn-bridge', 'br-int')
                   if ovsrec else 'br-int')
        matches = self._db.find_rows_by_name('Bridge', br_name)
        return next(iter(matches))

    def get_records(self, ofport):
        """Return a one-element list with the interface on 'ofport', or []."""
        intf = self._intfs.get(int(ofport))
        if not intf:
            return []
        return [intf]

    def print_record(self, intf):
        """Print the interface name and its 'iface-id' external id."""
        fields = (intf.name, intf.external_ids.get('iface-id'))
        print_p('OVS Interface: %s (%s)' % fields)

def print_record_from_cookie(ovnnb_db, cookie_handlers, cookie):
    """Print every record that any handler finds for 'cookie'.

    For each handler, in order, all matching records are collected first and
    then printed one by one together with their hints ('ovnnb_db' is passed
    on to print_hint()).  Every match after a handler's first is preceded by
    a '[Duplicate uuid cookie]' notice.
    """
    for handler in cookie_handlers:
        matches = tuple(handler.get_records(cookie))
        seen_one = False
        for record in matches:
            if seen_one:
                handler.print('[Duplicate uuid cookie]')
            seen_one = True
            handler.print_record(record)
            handler.print_hint(record, ovnnb_db)

def remote_is_ssl(remote):
    """Tell whether any entry of a comma-separated remote string uses ssl:.

    A falsy 'remote' (None or '') is returned unchanged; otherwise the result
    is True when the string starts with 'ssl:' or contains ',ssl:'.
    """
    if not remote:
        return remote
    if remote.startswith('ssl:'):
        return True
    return ',ssl:' in remote

def main():
    """Annotate an ofproto/trace dump read from stdin with database records.

    Parses the command line, connects to the southbound and northbound
    databases (and to the OVS database when --ovs is given), then echoes
    every input line.  At the end of each flow block it prints the records
    whose UUID starts with a cookie seen in that block and, with --ovs, the
    interfaces whose ofport matches an in_port/output value.

    Exits with status 1 on invalid options, on non-option arguments, or when
    an ssl: remote is used without key, certificate and CA certificate.
    """
    try:
        options, args = getopt.gnu_getopt(sys.argv[1:], 'hVp:c:C:',
                                          ['help', 'version', 'ovs',
                                           'ovnsb=', 'ovnnb=', 'ovsdb=',
                                           'private-key=', 'certificate=',
                                           'ca-cert='])
    except getopt.GetoptError as geo:
        sys.stderr.write("%s: %s\n" % (argv0, geo.msg))
        sys.exit(1)

    ovnsb_db = None
    ovnnb_db = None
    ovs_db   = None
    ovs      = False

    ssl_pk      = None
    ssl_cert    = None
    ssl_ca_cert = None

    for key, value in options:
        if key in ['-h', '--help']:
            usage()
        elif key in ['-V', '--version']:
            print("%s (OVN) 20.12.0" % argv0)
            # Only the version was requested: stop here rather than going on
            # to connect to the databases and wait for input.
            sys.exit(0)
        elif key in ['--ovnsb']:
            ovnsb_db = value
        elif key in ['--ovnnb']:
            ovnnb_db = value
        elif key in ['--ovsdb']:
            ovs_db = value
        elif key in ['-p', '--private-key']:
            ssl_pk = value
        elif key in ['-c', '--certificate']:
            ssl_cert = value
        elif key in ['-C', '--ca-cert']:
            ssl_ca_cert = value
        elif key in ['--ovs']:
            ovs = True
        else:
            sys.exit(0)

    if len(args) != 0:
        sys.stderr.write("%s: non-option argument not supported "
                         "(use --help for help)\n" % argv0)
        sys.exit(1)

    # If at least one of the remotes is SSL, make sure the SSL required args
    # were passed.
    for db in [ovnnb_db, ovnsb_db, ovs_db]:
        if remote_is_ssl(db) and \
                (not ssl_pk or not ssl_cert or not ssl_ca_cert):
            sys.stderr.write('%s: SSL connection requires private key, '
                             'certificate for private key, and peer CA '
                             'certificate as arguments.\n' % argv0)
            sys.exit(1)

    Stream.ssl_set_private_key_file(ssl_pk)
    Stream.ssl_set_certificate_file(ssl_cert)
    Stream.ssl_set_ca_cert_file(ssl_ca_cert)

    ovn_rundir = os.getenv('OVN_RUNDIR', '/var/run/ovn')
    ovs_rundir = os.getenv('OVS_RUNDIR', dirs.RUNDIR)

    # Remote precedence: command line, then environment variable, then the
    # unix socket under the run directory.
    if not ovnsb_db:
        ovnsb_db = os.getenv('OVN_SB_DB')
        if not ovnsb_db:
            ovnsb_db = 'unix:%s/ovnsb_db.sock' % ovn_rundir

    if not ovnnb_db:
        ovnnb_db = os.getenv('OVN_NB_DB')
        if not ovnnb_db:
            ovnnb_db = 'unix:%s/ovnnb_db.sock' % ovn_rundir
    if ovs and not ovs_db:
        ovs_db = 'unix:%s/db.sock' % ovs_rundir

    ovsdb_ovnsb = OVSDB(ovnsb_db, 'OVN_Southbound')
    ovsdb_ovnnb = OVSDB(ovnnb_db, 'OVN_Northbound')

    cookie_handlers = [
        LogicalFlowHandler(ovsdb_ovnnb, ovsdb_ovnsb),
        PortBindingHandler(ovsdb_ovnsb),
        MacBindingHandler(ovsdb_ovnsb),
        MulticastGroupHandler(ovsdb_ovnsb),
        ChassisHandler(ovsdb_ovnsb),
        SBLoadBalancerHandler(ovsdb_ovnsb)
    ]

    # Each entry pairs a line pattern with the handlers that receive the
    # pattern's first capture group.
    regex_cookie = re.compile(r'^.*cookie 0x([0-9a-fA-F]+)')
    regex_handlers = [
        (regex_cookie, cookie_handlers)
    ]

    if ovs:
        ovsdb_ovs = OVSDB(ovs_db, 'Open_vSwitch')
        # The '+' must sit inside the group: with '([0-9])+' only the last
        # digit of a multi-digit port number would be captured.
        regex_inport = re.compile(r'^ *[0-9]+\. *in_port=([0-9]+)')
        regex_outport = re.compile(r'^ *output:([0-9]+)')
        ofport_handlers = [
            OvsInterfaceHandler(ovsdb_ovs)
        ]
        regex_handlers += [
            (regex_outport, ofport_handlers),
            (regex_inport, ofport_handlers)
        ]

    regex_table_id = re.compile(r'^ *[0-9]+\.')
    cookies = []
    while True:
        line = sys.stdin.readline()
        if len(cookies) > 0:
            # Print record info when the current flow block ends.  At end of
            # input readline() returns '', which also counts as blank here.
            if regex_table_id.match(line) or line.strip() == '':
                for cookie, handlers in cookies:
                    print_record_from_cookie(ovsdb_ovnnb, handlers, cookie)
                cookies = []

        print(line.strip())
        if line == '':
            break

        for regex, handlers in regex_handlers:
            m = regex.match(line)
            if m:
                cookies.append((m.group(1), handlers))


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


# Local variables:
# mode: python
# End:
