#!/usr/bin/env python

# Copyright Amazon.com, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the MIT License. See the LICENSE accompanying this file
# for the specific language governing permissions and limitations under
# the License.

"""
Usage:
Reads EBS information from EBS NVMe device
"""

from __future__ import print_function
import argparse
from ctypes import Structure, Array, c_uint8, c_uint16, c_uint32, c_uint64, \
    c_char, addressof, sizeof, byref
from fcntl import ioctl
import json
import os
import signal
import sys
import time

NVME_ADMIN_IDENTIFY = 0x06
NVME_GET_LOG_PAGE = 0x02
NVME_IOCTL_ADMIN_CMD = 0xC0484E41

AMZN_NVME_EBS_MN = "Amazon Elastic Block Store"
AMZN_NVME_STATS_LOGPAGE_ID = 0xD0
AMZN_NVME_STATS_MAGIC = 0x3C23B510
AMZN_NVME_VID = 0x1D0F


class structure_dict_mixin:
    """Mixin giving a ctypes Structure a plain-dict view of its fields."""

    def to_dict(self):
        """Return a ``{field name: value}`` dict of the exportable fields.

        A field is left out when its name starts with an underscore or
        when its value is itself a ctypes ``Structure`` or ``Array``;
        classes that want nested data exported override this method.
        """
        exported = {}
        for descriptor in self._fields_:
            name = descriptor[0]
            # Underscore-prefixed fields are private/reserved.
            if name.startswith("_"):
                continue
            value = getattr(self, name)
            # Nested aggregates are not exported by this generic helper.
            if isinstance(value, (Structure, Array)):
                continue
            exported[name] = value
        return exported


class nvme_admin_command(Structure):
    """Packed 72-byte command block passed to the NVMe admin ioctl.

    The order and width of the fields define the binary layout, so they
    must not be reordered or resized.
    NOTE(review): the layout presumably mirrors the kernel's NVMe admin
    passthrough command structure - confirm against the kernel headers.
    """
    _pack_ = 1
    _fields_ = [
        ("opcode", c_uint8),        # admin command opcode
        ("flags", c_uint8),
        ("cid", c_uint16),
        ("nsid", c_uint32),
        ("_reserved0", c_uint64),
        ("mptr", c_uint64),
        ("addr", c_uint64),         # address of the caller's data buffer
        ("mlen", c_uint32),
        ("alen", c_uint32),         # size of the data buffer in bytes
        ("cdw10", c_uint32),        # command-specific dwords 10..15
        ("cdw11", c_uint32),
        ("cdw12", c_uint32),
        ("cdw13", c_uint32),
        ("cdw14", c_uint32),
        ("cdw15", c_uint32),
        ("_reserved1", c_uint64),
    ]


class nvme_identify_controller_amzn_vs(Structure):
    """Amazon vendor-specific area of the Identify Controller data (1024 bytes)."""
    _pack_ = 1
    _fields_ = [
        # block device name
        ("bdev", c_char * 32),
        # remainder of the 1024-byte area
        ("_reserved0", c_char * (1024 - 32)),
    ]


class nvme_identify_controller_psd(Structure):
    """One 32-byte power state descriptor of the Identify Controller data."""
    _pack_ = 1
    _fields_ = [
        ("mp", c_uint16),               # maximum power
        ("_reserved0", c_uint16),
        ("enlat", c_uint32),            # entry latency
        ("exlat", c_uint32),            # exit latency
        ("rrt", c_uint8),               # relative read throughput
        ("rrl", c_uint8),               # relative read latency
        ("rwt", c_uint8),               # relative write throughput
        ("rwl", c_uint8),               # relative write latency
        ("_reserved1", c_char * 16),
    ]


class nvme_identify_controller(Structure):
    """Buffer layout for the Identify Controller data (packed, 4096 bytes).

    Used as the data buffer of the NVME_ADMIN_IDENTIFY admin command.
    Field order and the computed sizes of the ``_reserved*`` gaps fix the
    byte offset of every member, so nothing here may be reordered.
    """
    _pack_ = 1
    _fields_ = [("vid", c_uint16),          # PCI Vendor ID
                ("ssvid", c_uint16),        # PCI Subsystem Vendor ID
                ("sn", c_char * 20),        # Serial Number
                ("mn", c_char * 40),        # Model Number
                ("fr", c_char * 8),         # Firmware Revision
                ("rab", c_uint8),           # Recommended Arbitration Burst
                ("ieee", c_uint8 * 3),      # IEEE OUI Identifier
                ("mic", c_uint8),           # Multi-Interface Capabilities
                ("mdts", c_uint8),          # Maximum Data Transfer Size
                ("_reserved0", c_uint8 * (256 - 78)),  # pads bytes 78..255
                ("oacs", c_uint16),         # Optional Admin Command Support
                ("acl", c_uint8),           # Abort Command Limit
                ("aerl", c_uint8),          # Asynchronous Event Request Limit
                ("frmw", c_uint8),          # Firmware Updates
                ("lpa", c_uint8),           # Log Page Attributes
                ("elpe", c_uint8),          # Error Log Page Entries
                ("npss", c_uint8),          # Number of Power States Support
                ("avscc", c_uint8),         # Admin Vendor Specific Command Configuration # noqa
                ("_reserved1", c_uint8 * (512 - 265)),  # pads bytes 265..511
                ("sqes", c_uint8),          # Submission Queue Entry Size
                ("cqes", c_uint8),          # Completion Queue Entry Size
                ("_reserved2", c_uint16),
                ("nn", c_uint32),           # Number of Namespaces
                ("oncs", c_uint16),         # Optional NVM Command Support
                ("fuses", c_uint16),        # Fused Operation Support
                ("fna", c_uint8),           # Format NVM Attributes
                ("vwc", c_uint8),           # Volatile Write Cache
                ("awun", c_uint16),         # Atomic Write Unit Normal
                ("awupf", c_uint16),        # Atomic Write Unit Power Fail
                ("nvscc", c_uint8),         # NVM Vendor Specific Command Configuration # noqa
                ("_reserved3", c_uint8 * (704 - 531)),  # pads bytes 531..703
                ("_reserved4", c_uint8 * (2048 - 704)),  # pads bytes 704..2047
                ("psd", nvme_identify_controller_psd * 32),  # Power State Descriptor # noqa
                ("vs", nvme_identify_controller_amzn_vs)]  # Vendor Specific


class nvme_histogram_bin(Structure, structure_dict_mixin):
    """One histogram bin: a lower bound, an upper bound and a count (24 bytes)."""
    _pack_ = 1
    _fields_ = [
        ("lower", c_uint64),
        ("upper", c_uint64),
        ("count", c_uint32),
        ("_reserved0", c_uint32),
    ]

    def to_human_readable(self):
        """Print the bin as ``[lower - upper] => count``.

        Both bounds are left-aligned in 8-character columns.
        """
        template = "[{0.lower:<8} - {0.upper:<8}] => {0.count}"
        rendered = template.format(self)
        print(rendered)


class ebs_nvme_histogram(Structure, structure_dict_mixin):
    """Histogram made of a bin count followed by a fixed table of 64 bins.

    Only the first ``num_bins`` entries of ``bins`` are meaningful.
    """
    _pack_ = 1
    _fields_ = [("num_bins", c_uint64),
                ("bins", nvme_histogram_bin * 64)]

    def _populated_bins(self):
        """Return the bins in use as a list, never more than the table holds.

        ``num_bins`` is filled in by the device, so it is capped at the
        table size: an out-of-range value must not index past ``bins``.
        """
        in_use = min(self.num_bins, len(self.bins))
        return self.bins[:in_use]

    def to_dict(self):
        """Return the public fields plus a ``"bins"`` list of per-bin dicts."""
        result = super(ebs_nvme_histogram, self).to_dict()
        result["bins"] = [entry.to_dict() for entry in self._populated_bins()]

        return result

    def to_human_readable(self):
        """Print the bin count, a table header and one line per bin in use."""
        print("Number of bins: {0}".format(self.num_bins))

        print("=================================")
        print("Lower       Upper        IO Count")
        print("=================================")

        for entry in self._populated_bins():
            entry.to_human_readable()


class nvme_get_amzn_stats_logpage(Structure, structure_dict_mixin):
    """Buffer layout of the Amazon EBS statistics log page (packed, 4096 bytes).

    Holds cumulative I/O counters, the time spent above volume/instance
    limits, the current queue length and read/write latency histograms.
    """
    _pack_ = 1
    _fields_ = [
        ("_magic", c_uint32),
        ("_reserved0", c_char * 4),
        # total number of read operations
        ("total_read_ops", c_uint64),
        # total number of write operations
        ("total_write_ops", c_uint64),
        # total bytes read
        ("total_read_bytes", c_uint64),
        # total bytes written
        ("total_write_bytes", c_uint64),
        # total time spent on read operations (in microseconds)
        ("total_read_time", c_uint64),
        # total time spent on write operations (in microseconds)
        ("total_write_time", c_uint64),
        # time EBS volume IOPS limit was exceeded (in microseconds)
        ("ebs_volume_performance_exceeded_iops", c_uint64),
        # time EBS volume throughput limit was exceeded (in microseconds)
        ("ebs_volume_performance_exceeded_tp", c_uint64),
        # time EC2 instance EBS IOPS limit was exceeded (in microseconds)
        ("ec2_instance_ebs_performance_exceeded_iops", c_uint64),
        # time EC2 instance EBS throughput limit was exceeded (in microseconds)
        ("ec2_instance_ebs_performance_exceeded_tp", c_uint64),
        # current volume queue length
        ("volume_queue_length", c_uint64),
        ("_reserved1", c_char * 416),
        # histogram of read I/O latencies (in microseconds)
        ("read_io_latency_histogram", ebs_nvme_histogram),
        # histogram of write I/O latencies (in microseconds)
        ("write_io_latency_histogram", ebs_nvme_histogram),
        ("_reserved2", c_char * 496),
    ]

    def to_dict(self):
        """Return the public counters plus both histograms as nested dicts."""
        exported = super(nvme_get_amzn_stats_logpage, self).to_dict()
        for name in ("read_io_latency_histogram", "write_io_latency_histogram"):
            exported[name] = getattr(self, name).to_dict()

        return exported

    def to_json(self):
        """Print the ``to_dict()`` view as a single JSON document."""
        serialized = json.dumps(self.to_dict())
        print(serialized)

    def to_human_readable(self):
        """Print every counter and both latency histograms as plain text."""
        summary = [
            "Total Ops",
            "  Read: {0}".format(self.total_read_ops),
            "  Write: {0}".format(self.total_write_ops),
            "Total Bytes",
            "  Read: {0}".format(self.total_read_bytes),
            "  Write: {0}".format(self.total_write_bytes),
            "Total Time (us)",
            "  Read: {0}".format(self.total_read_time),
            "  Write: {0}".format(self.total_write_time),
            "EBS Volume Performance Exceeded (us)",
            "  IOPS: {0}".format(self.ebs_volume_performance_exceeded_iops),
            "  Throughput: {0}".format(self.ebs_volume_performance_exceeded_tp),
            "EC2 Instance EBS Performance Exceeded (us)",
            "  IOPS: {0}".format(self.ec2_instance_ebs_performance_exceeded_iops),
            "  Throughput: {0}".format(self.ec2_instance_ebs_performance_exceeded_tp),
            "Queue Length (point in time): {0} \n".format(self.volume_queue_length),
            "Read IO Latency Histogram (us)",
        ]
        for line in summary:
            print(line)
        self.read_io_latency_histogram.to_human_readable()

        print("\nWrite IO Latency Histogram (us)")
        self.write_io_latency_histogram.to_human_readable()


class ebs_nvme_device:
    """Base class that issues NVMe admin commands to one device node."""

    def __init__(self, device):
        """Store the path of the device node that commands are sent to."""
        self.device = device

    def _nvme_ioctl(self, admin_cmd):
        """Send ``admin_cmd`` to the device through the NVMe admin ioctl.

        The device is opened for the duration of the call only. Any data
        the command returns is written by the kernel into the buffer that
        ``admin_cmd`` points at.

        Raises:
            IOError: the device node could not be opened.
            SystemExit: the ioctl itself failed; the error is reported on
                stderr and the process exits with status 1.
        """
        with open(self.device, "r") as dev:
            try:
                ioctl(dev, NVME_IOCTL_ADMIN_CMD, admin_cmd)
            except (IOError, OSError) as err:
                # Report on stderr so the diagnostic never mixes with the
                # stats/JSON written to stdout. IOError is listed as well
                # because Python 2's ioctl raises it instead of OSError.
                print("Failed to issue nvme cmd, err: ", err, file=sys.stderr)
                sys.exit(1)


class ebs_nvme_device_stats(ebs_nvme_device):
    """Reads the Amazon EBS statistics log page and prints it.

    Supports a one-shot dump of the counters as well as interval polling,
    where every sample after the first is printed as the difference from
    the previous one.
    """

    # Previous sample used by interval polling; None until one is taken.
    prev = None

    def _query_stats_from_device(self):
        """Fetch the statistics log page from the device and return it.

        Raises:
            TypeError: the page does not carry the EBS statistics magic
                number, i.e. the device is not an EBS volume.
        """
        stats = nvme_get_amzn_stats_logpage()
        admin_cmd = nvme_admin_command(
            opcode=NVME_GET_LOG_PAGE,
            addr=addressof(stats),
            alen=sizeof(stats),
            nsid=1,
            # Log page id in the low bits, transfer size in the high 16 bits.
            # NOTE(review): the NVMe dword count is zero-based, so 1024 asks
            # for one dword more than the 4096-byte buffer - confirm.
            cdw10=(AMZN_NVME_STATS_LOGPAGE_ID | (1024 << 16))
        )
        self._nvme_ioctl(admin_cmd)

        if stats._magic != AMZN_NVME_STATS_MAGIC:
            raise TypeError("[ERROR] Not an EBS device: {0}".format(self.device))

        return stats

    def _get_stats_diff(self):
        """Take a new sample and return its difference from the last one.

        The first call has nothing to compare against, so it returns the
        raw sample. Every call stores the new sample in ``self.prev``.
        """
        curr = self._query_stats_from_device()
        if self.prev is None:
            self.prev = curr
            return curr

        diff = nvme_get_amzn_stats_logpage()
        # Queue length is a point-in-time value, so it is copied, not diffed.
        diff.volume_queue_length = curr.volume_queue_length

        for field, _ in nvme_get_amzn_stats_logpage._fields_:
            if field.startswith('_') or field == "volume_queue_length":
                continue
            prev_value = getattr(self.prev, field)
            # Nested histograms are diffed separately below. Testing for the
            # aggregate types (rather than for ``int``) keeps the counters
            # working on Python 2, where ctypes returns ``long`` values.
            if isinstance(prev_value, (Structure, Array)):
                continue
            setattr(diff, field, getattr(curr, field) - prev_value)

        for histogram_field in ['read_io_latency_histogram', 'write_io_latency_histogram']:
            self._calculate_histogram_diff(diff, curr, self.prev, histogram_field)

        self.prev = curr
        return diff

    def _calculate_histogram_diff(self, diff, curr, prev, histogram_field):
        """Fill ``diff``'s histogram with the per-bin count change.

        Bin bounds are copied from ``curr``; each count is ``curr`` minus
        ``prev``. The bin count reported by the device is capped at the
        size of the bin table so a bogus value cannot index past it.
        """
        prev_hist = getattr(prev, histogram_field)
        curr_hist = getattr(curr, histogram_field)
        diff_hist = getattr(diff, histogram_field)

        bin_count = min(curr_hist.num_bins, len(diff_hist.bins))
        diff_hist.num_bins = bin_count
        for i in range(bin_count):
            diff_hist.bins[i].lower = curr_hist.bins[i].lower
            diff_hist.bins[i].upper = curr_hist.bins[i].upper
            diff_hist.bins[i].count = curr_hist.bins[i].count - prev_hist.bins[i].count

    def _print_stats(self, stats, json_format=False):
        """Print ``stats`` as JSON when ``json_format`` is set, else as text."""
        if json_format:
            print(json.dumps(stats.to_dict()))
        else:
            stats.to_human_readable()

    def _signal_handler(self, sig, frame):
        """Exit with status 0; installed for SIGINT while polling."""
        sys.exit(0)

    def get_stats(self, interval=0, json_format=False):
        """Print the device statistics once, or repeatedly every ``interval``.

        Args:
            interval: seconds between samples; any value not greater than
                zero prints a single sample and returns.
            json_format: print JSON instead of the human-readable report.

        With a positive interval this loops until the process is
        interrupted; Ctrl+C (SIGINT) exits with status 0.
        """
        if interval > 0:
            print("Polling EBS stats every {0} sec(s); press Ctrl+C to stop".format(interval))
            signal.signal(signal.SIGINT, self._signal_handler)
            self.prev = None

            while True:
                self._print_stats(self._get_stats_diff(), json_format)
                time.sleep(interval)
                print("\n")

        else:
            self._print_stats(self._query_stats_from_device(), json_format)

class ebs_nvme_device_id(ebs_nvme_device):
    """Reads the EBS volume id and block device name of an NVMe device."""

    def get_id(self, volume=False, block_dev=False, udev=False):
        """Print the requested identification data.

        Args:
            volume: print the volume id.
            block_dev: print the block device name.
            udev: print the block device name without a leading ``/dev/``.

        With no option set, both the volume id and the block device name
        are printed.

        Raises:
            TypeError: the device is not an EBS volume.
        """
        id_ctrl = self._query_id_ctrl_from_device()

        if not (volume or block_dev or udev):
            print("Volume ID: {0}".format(self._get_volume_id(id_ctrl)))
            print(self._get_block_device(id_ctrl))
        else:
            if volume:
                print("Volume ID: {0}".format(self._get_volume_id(id_ctrl)))
            if block_dev or udev:
                print(self._get_block_device(id_ctrl, udev))

    def _query_id_ctrl_from_device(self):
        """Run Identify Controller on the device and return the data.

        Raises:
            TypeError: the vendor id or model name is not that of an EBS
                volume.
        """
        id_ctrl = nvme_identify_controller()
        admin_cmd = nvme_admin_command(
            opcode=NVME_ADMIN_IDENTIFY,
            addr=addressof(id_ctrl),
            alen=sizeof(id_ctrl),
            cdw10=1
        )
        self._nvme_ioctl(admin_cmd)

        if id_ctrl.vid != AMZN_NVME_VID or id_ctrl.mn.decode().strip() != AMZN_NVME_EBS_MN:
            # Format a single message; passing the device as a second
            # argument made the error print as a tuple.
            raise TypeError("[ERROR] Not an EBS device: {0}".format(self.device))

        return id_ctrl

    def _get_volume_id(self, id_ctrl):
        """Return the volume id taken from the controller serial number.

        A serial of the form ``volXXXX`` is rewritten as ``vol-XXXX``; any
        other value is returned unchanged.
        """
        vol = id_ctrl.sn.decode()
        # Prefix test instead of vol[3], which fails on a bare "vol".
        if vol.startswith("vol") and not vol.startswith("vol-"):
            vol = "vol-" + vol[3:]
        return vol

    def _get_block_device(self, id_ctrl, stripped=False):
        """Return the block device name from the vendor-specific area.

        With ``stripped`` set, a leading ``/dev/`` is removed.
        """
        dev = id_ctrl.vs.bdev.decode().strip()
        if stripped and dev.startswith("/dev/"):
            dev = dev[5:]
        return dev


def main(argv=None):
    """Parse the command line and run the ``stats`` or ``id`` command.

    Args:
        argv: arguments without the program name; defaults to
            ``sys.argv[1:]``.

    Exits with status 1, after printing the error on stderr, when the
    device cannot be opened or is not an EBS volume. Usage errors are
    handled by argparse.
    """
    argv = list(sys.argv[1:] if argv is None else argv)

    # check if the script is being called as ebsnvme-id
    if os.path.basename(sys.argv[0]) == 'ebsnvme-id':
        argv.insert(0, 'id')

    parser = argparse.ArgumentParser(description="Reads EBS information from EBS NVMe devices.")
    cmd_parser = parser.add_subparsers(dest="cmd", help="Available commands")
    # Setting required on the subparsers action makes the command mandatory.
    cmd_parser.required = True

    stats_parser = cmd_parser.add_parser("stats", help="Get EBS NVMe stats")
    stats_parser.add_argument("device", help="Device to query")
    stats_parser.add_argument("-j", "--json", action="store_true",
                              help="Output in json format")
    stats_parser.add_argument("-i", "--interval", type=int, default=0,
                              help='Interval in seconds to poll ebs stats')

    id_parser = cmd_parser.add_parser("id", help="Display id options")
    id_parser.add_argument("device", help="Device to query")
    id_parser.add_argument("-v", "--volume", action="store_true",
                           help="Return volume-id")
    id_parser.add_argument("-b", "--block-dev", action="store_true",
                           help="Return block device mapping")
    id_parser.add_argument("-u", "--udev", action="store_true",
                           help="Output data in format suitable for udev rules")

    args = parser.parse_args(argv)

    try:
        if args.cmd == "stats":
            stats = ebs_nvme_device_stats(args.device)
            stats.get_stats(interval=args.interval, json_format=args.json)

        elif args.cmd == "id":
            id_info = ebs_nvme_device_id(args.device)
            id_info.get_id(volume=args.volume, block_dev=args.block_dev, udev=args.udev)
    except (IOError, TypeError) as err:
        print(err, file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
