TPU Monitoring in MLFlow Part 2

Huge thanks to Google TPU Research Cloud for providing me with access to TPU chips!

As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.

This version spawns a background process (not a thread; jaxlib didn’t like that) that periodically sends the relevant TPU information to the MLFlow Tracking server.

import multiprocessing
import mlflow

class TPUMonitoring(multiprocessing.Process):
    """Background process that periodically logs TPU metrics to an MLFlow run."""

    def __init__(self, run_id, sampling_interval_seconds=1):
        # daemon=True: if stop() is never called, the monitor is terminated at
        # interpreter exit instead of being waited on forever
        super().__init__(daemon=True)
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()
        self.client = None  # created in run(), inside the child process

    def log_tpu_mlflow(self, step):
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            for metric, val in tpu.items():
                if metric != "device_id":
                    # One metric per chip and field, e.g. system/<chip>/<id>/memory_usage
                    self.client.log_metric(
                        run_id=self.run_id,
                        key=f"system/{chip_type}/{tpu['device_id']}/{metric}",
                        value=float(val),
                        step=step,
                    )

    def tpu_system_metrics(self):
        try:
            # Imported here so tpu_info is only loaded where metrics are collected
            import tpu_info.device
            import tpu_info.metrics
            chip_type, count = tpu_info.device.get_local_chips()
            if chip_type is None:
                return "no tpu detected", 0, []
            usage = tpu_info.metrics.get_chip_usage(chip_type)
            data = [
                {
                    "device_id": d.device_id,
                    "memory_usage": d.memory_usage,
                    "total_memory": d.total_memory,
                    "duty_cycle_pct": d.duty_cycle_pct,
                }
                for d in usage
            ]
            return chip_type, count, data
        except Exception as e:
            print(e)
            return "no tpu detected", 0, []

    def run(self):
        # Runs in the child process; creating the client here means it never
        # has to be pickled when the process is started
        self.client = mlflow.client.MlflowClient()
        try:
            step = 0
            while not self._stop_event.is_set():
                self.log_tpu_mlflow(step)
                step += 1
                # Sleep for the interval, but wake up immediately if stop() is called
                self._stop_event.wait(self.sampling_interval)
        except Exception as e:
            print(f"Error in background logger process: {e}")

    def stop(self):
        self._stop_event.set()
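To see what actually ends up in the tracking server, here is the naming scheme from `log_tpu_mlflow` pulled out into a pure function. This is only a sketch: `tpu_metric_keys` is a hypothetical helper, the chip label `"v4"` is a placeholder (the real `chip_type` comes from `tpu_info` and its string form may look different), and the usage numbers are invented, so it runs without a TPU or an MLFlow server.

```python
def tpu_metric_keys(chip_type, data):
    """Map per-chip usage records to metric keys, mirroring log_tpu_mlflow.

    Every field except "device_id" becomes one metric named
    system/<chip_type>/<device_id>/<field>, with its value cast to float.
    """
    metrics = {}
    for tpu in data:
        for metric, val in tpu.items():
            if metric != "device_id":
                metrics[f"system/{chip_type}/{tpu['device_id']}/{metric}"] = float(val)
    return metrics


# Illustrative record only: the numbers are made up, not measured on a TPU.
sample = [{"device_id": 0, "memory_usage": 1024, "total_memory": 4096, "duty_cycle_pct": 37.5}]
for key, value in tpu_metric_keys("v4", sample).items():
    print(key, value)
```

Each chip therefore contributes three metrics per sampling step, all under the same `system/` prefix that MLFlow’s built-in system metrics use.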

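To use it in your experiment, start the monitor inside the run and stop it once training is done:

# Enable before start_run() so MLFlow's own system metrics apply to this run
mlflow.enable_system_metrics_logging()

with mlflow.start_run() as run:
    tpu_monitor = TPUMonitoring(run_id=run.info.run_id)
    tpu_monitor.start()
    try:
        ...  # your training code
    finally:
        # Stop the sampling loop and wait for the background process to exit
        tpu_monitor.stop()
        tpu_monitor.join()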

And that’s it!

If you have logging enabled, you may see some info/warning messages from protobuf and tpu_info; in my experience they are safe to ignore.

I also wanted to attach the start function to the MLFlow module, with something like:

    mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start

That way it would look nicer and more ‘built-in’:

mlflow.enable_system_metrics_logging()

with mlflow.start_run() as run:
    mlflow.enable_tpu_metrics_logging()

But I leave that as future work.
