TPU Monitoring in MLFlow Part 2

As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.

The snippet below spawns a background process (not a thread; jaxlib didn’t like that) that periodically sends the relevant TPU metrics to the MLFlow Tracking server.

import multiprocessing
import mlflow

class TPUMonitoring(multiprocessing.Process):
    def __init__(self, run_id, sampling_interval_seconds=1):
        # daemon=True so the monitor dies together with the training script,
        # even if stop() is never called
        super().__init__(daemon=True)
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()

    def log_tpu_mlflow(self, step):
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            for metric, val in tpu.items():
                if metric != "device_id":
                    # One metric per chip and per counter, grouped under system/
                    self.client.log_metric(
                        run_id=self.run_id,
                        key=f"system/{chip_type}/{tpu['device_id']}/{metric}",
                        value=float(val),
                        step=step,
                    )


    def tpu_system_metrics(self):
        try:
            from tpu_info import device, metrics
            chip_type, count = device.get_local_chips()
            if chip_type is not None:
                usage = metrics.get_chip_usage(chip_type)
                data = [
                    {
                        "device_id": chip.device_id,
                        "memory_usage": chip.memory_usage,
                        "total_memory": chip.total_memory,
                        "duty_cycle_pct": chip.duty_cycle_pct,
                    }
                    for chip in usage
                ]
                return chip_type, count, data
            return "no tpu detected", 0, []
        except Exception as e:
            # tpu_info fails if no TPU runtime is reachable; degrade gracefully
            print(e)
            return "no tpu detected", 0, []

    def run(self):
        try:
            # Create the client inside the child process; MlflowClient is not
            # guaranteed to survive pickling under the "spawn" start method
            self.client = mlflow.client.MlflowClient()
            step = 0
            while not self._stop_event.is_set():
                self.log_tpu_mlflow(step)
                step += 1
                self._stop_event.wait(self.sampling_interval)
        except Exception as e:
            print(f"Error in background logger process: {e}")

    def stop(self):
        # Signal the sampling loop to exit after its current wait
        self._stop_event.set()
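
If you want to sanity-check the tpu_info side on its own before wiring it into MLFlow, the same two calls the class uses can be run directly (assuming the tpu-info package is installed on your TPU VM):

from tpu_info import device, metrics

# Same calls TPUMonitoring makes, just printed instead of logged
chip_type, count = device.get_local_chips()
print(f"Found {count} x {chip_type}")
if chip_type is not None:
    for chip in metrics.get_chip_usage(chip_type):
        print(chip.device_id, chip.memory_usage, chip.total_memory, chip.duty_cycle_pct)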

To use this in your experiment, you need something like this:

with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    TPUMonitoring(run_id=run.info.run_id).start()
    ...

And that’s it!
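
Because the process is a daemon, it will simply die with your training script. But if you’d rather shut it down explicitly (and let it finish its last write), you can keep a reference around; monitor here is just a local name I made up:

with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    monitor = TPUMonitoring(run_id=run.info.run_id)
    monitor.start()
    ...  # your training loop
    monitor.stop()  # ask the sampling loop to exit
    monitor.join()  # wait for the final metrics to be sent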

If you have logging enabled you might get some info/warning messages from protobuf and tpu_info, but in my experience they are safe to ignore.
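
If they bother you, the standard logging module can silence them. Note that the logger names below are guesses on my part, so adjust them to match whatever shows up in your output:

import logging

# Logger names are assumptions; match them to the noisy messages you see
for name in ("tpu_info", "google.protobuf"):
    logging.getLogger(name).setLevel(logging.ERROR)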

I also wanted to attach the start function to the MLFlow module with something like:

    mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start

So that it would look nicer and ‘built-in’, like:

with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    mlflow.enable_tpu_metrics_logging()

But I leave that as future work.
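
In case someone wants to pick that up, here is a rough sketch of what the patch could look like. The helper name just mirrors the attribute from the example above, and mlflow.active_run() resolves the run when the function is called rather than when the attribute is assigned, which is the catch with the one-liner earlier:

import mlflow

def _enable_tpu_metrics_logging(sampling_interval_seconds=1):
    # Look up the active run at call time instead of binding a run_id
    # when the attribute is assigned
    run = mlflow.active_run()
    if run is None:
        raise RuntimeError("Call this inside an active MLflow run")
    monitor = TPUMonitoring(
        run_id=run.info.run_id,
        sampling_interval_seconds=sampling_interval_seconds,
    )
    monitor.start()
    return monitor

mlflow.enable_tpu_metrics_logging = _enable_tpu_metrics_logging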

 