TPU Monitoring in MLFlow Part 2

Huge thanks to Google TPU Research Cloud for providing me with access to TPU chips!

As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.

The class below spawns a background process (not a thread, jaxlib didn’t like that) that periodically sends the relevant TPU information to the MLFlow Tracking server.

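import multiprocessing

import mlflow


class TPUMonitoring(multiprocessing.Process):
    """Background process that periodically logs TPU metrics to an MLFlow run."""

    def __init__(self, run_id, sampling_interval_seconds=1):
        # daemon=True: a monitor that is never stopped must not keep the
        # training script alive after it finishes.
        super().__init__(daemon=True)
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()
        # Remember the tracking URI here; the client itself is created in run(),
        # inside the child process, so it never has to be pickled.
        self.tracking_uri = mlflow.get_tracking_uri()
        self.client = None

    def log_tpu_mlflow(self, step):
        """Log one sample per chip under system/<chip_type>/<device_id>/<metric>."""
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            for metric, val in tpu.items():
                if metric == "device_id":
                    continue
                self.client.log_metric(
                    run_id=self.run_id,
                    key=f"system/{chip_type}/{tpu['device_id']}/{metric}",
                    value=float(val),
                    step=step,
                )

    def tpu_system_metrics(self):
        """Return (chip_type, chip_count, list of per-chip usage dicts)."""
        try:
            # Imported lazily so a missing tpu_info only disables the metrics.
            import tpu_info, tpu_info.metrics, tpu_info.device
            chip_type, count = tpu_info.device.get_local_chips()
            if chip_type is None:
                return "no tpu detected", 0, []
            usage = tpu_info.metrics.get_chip_usage(chip_type)
            data = [
                {
                    "device_id": d.device_id,
                    "memory_usage": d.memory_usage,
                    "total_memory": d.total_memory,
                    "duty_cycle_pct": d.duty_cycle_pct,
                }
                for d in usage
            ]
            return chip_type, count, data
        except Exception as e:
            print(e)
            return "no tpu detected", 0, []

    def run(self):
        try:
            self.client = mlflow.client.MlflowClient(tracking_uri=self.tracking_uri)
            step = 0
            while not self._stop_event.is_set():
                self.log_tpu_mlflow(step)
                step += 1
                # wait() doubles as an interruptible sleep: it returns early
                # as soon as stop() sets the event.
                self._stop_event.wait(self.sampling_interval)
        except Exception as e:
            print(f"Error in background logger process: {e}")

    def stop(self):
        self._stop_event.set()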
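
To make the metric naming concrete, here is a small standalone sketch of the keys that log_tpu_mlflow ends up sending. The helper name flatten_tpu_metrics, the chip type "v4" and the sample values are all made up for illustration; nothing here talks to MLFlow or to a real TPU.

```python
def flatten_tpu_metrics(chip_type, chips):
    """Turn per-chip usage dicts into (metric_key, value) pairs.

    Mirrors the naming used in log_tpu_mlflow:
    system/<chip_type>/<device_id>/<metric>
    """
    pairs = []
    for chip in chips:
        for metric, val in chip.items():
            if metric == "device_id":
                continue  # the device id is part of the key, not a metric
            key = f"system/{chip_type}/{chip['device_id']}/{metric}"
            pairs.append((key, float(val)))
    return pairs


# Made-up sample readings for two chips; "v4" is just a placeholder chip type.
sample = [
    {"device_id": 0, "memory_usage": 1024, "total_memory": 4096, "duty_cycle_pct": 87.5},
    {"device_id": 1, "memory_usage": 2048, "total_memory": 4096, "duty_cycle_pct": 12.0},
]

for key, value in flatten_tpu_metrics("v4", sample):
    print(key, value)
```

Each chip gets its own set of keys, so every chip's memory usage and duty cycle are logged as separate metrics.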

To use this in your experiment you need something like this:

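with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    tpu_monitor = TPUMonitoring(run_id=run.info.run_id)
    tpu_monitor.start()
    ...
    # Stop the background process once training is done.
    tpu_monitor.stop()
    tpu_monitor.join()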
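
The monitor keeps sampling until its stop() method is called, so it can be handy to wrap it in a context manager that always cleans up, even if training raises. This is only a sketch: running is a name I made up, and it assumes nothing more than an object with start(), stop() and join() methods, which the TPUMonitoring class has.

```python
from contextlib import contextmanager


@contextmanager
def running(monitor):
    """Start `monitor`, yield it, and always stop and join it afterwards."""
    monitor.start()
    try:
        yield monitor
    finally:
        # Runs on normal exit and when the body raises.
        monitor.stop()
        monitor.join()


# Hypothetical usage with the class from this post:
#
# with mlflow.start_run() as run, running(TPUMonitoring(run_id=run.info.run_id)):
#     ...  # training loop
```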

And that’s it!

If you have logging enabled, you might see some info/warning messages from protobuf and tpu_info, but in my experience they are safe to ignore.

I also wanted to attach the start function to the MLFlow module, with something like:

    mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start

So that it would look nicer and ‘built-in’, like:

with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    mlflow.enable_tpu_metrics_logging()

But I leave that as future work.
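
One catch with the one-line assignment is that it needs run to exist already and ties the function to that single run. A possible way around it, as an untested sketch, is to look up the active run at call time instead. The factory name make_tpu_logging_enabler is my own invention; it takes the monitor class and a "get the active run" callable as arguments, the idea being that you pass in TPUMonitoring and mlflow.active_run.

```python
def make_tpu_logging_enabler(monitor_cls, get_active_run):
    """Build a zero-argument function that starts a monitor for the active run.

    monitor_cls: a class like TPUMonitoring, accepting run_id=...
    get_active_run: a callable like mlflow.active_run, returning a run or None
    """
    def enable_tpu_metrics_logging():
        run = get_active_run()
        if run is None:
            raise RuntimeError(
                "No active MLFlow run; call this inside mlflow.start_run()."
            )
        monitor = monitor_cls(run_id=run.info.run_id)
        monitor.start()
        return monitor  # keep this handle so stop() can be called later

    return enable_tpu_metrics_logging


# Hypothetical wiring (assumes mlflow and TPUMonitoring are importable):
#
# mlflow.enable_tpu_metrics_logging = make_tpu_logging_enabler(
#     TPUMonitoring, mlflow.active_run
# )
```

Returning the monitor is a deliberate choice here: the caller still has a way to stop the background process when the run ends.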

 