TPU Monitoring in MLFlow Part 2

Huge thanks to Google TPU Research Cloud for providing me with access to TPU chips!

As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.

The code below spawns a background process (not a thread; jaxlib didn’t like that) that periodically sends the relevant TPU information to the MLFlow Tracking server.

import multiprocessing

import mlflow


class TPUMonitoring(multiprocessing.Process):
    """Background process that periodically logs TPU metrics to MLFlow."""

    def __init__(self, run_id, sampling_interval_seconds=1):
        # daemon=True so that a monitor which is never stopped does not
        # keep the training script alive when it exits.
        super().__init__(daemon=True)
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()
        self.client = mlflow.client.MlflowClient()

    def log_tpu_mlflow(self, step):
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            for metric, val in tpu.items():
                if metric == "device_id":
                    continue  # device_id is part of the key, not a metric
                self.client.log_metric(
                    run_id=self.run_id,
                    key=f"system/{chip_type}/{tpu['device_id']}/{metric}",
                    value=float(val),
                    step=step,
                )

    def tpu_system_metrics(self):
        try:
            # Imported here so that a missing tpu_info package is caught
            # below and reported as "no tpu detected".
            import tpu_info.device
            import tpu_info.metrics

            chip_type, count = tpu_info.device.get_local_chips()
            if chip_type is None:
                return "no tpu detected", 0, []
            data = [
                {
                    "device_id": d.device_id,
                    "memory_usage": d.memory_usage,
                    "total_memory": d.total_memory,
                    "duty_cycle_pct": d.duty_cycle_pct,
                }
                for d in tpu_info.metrics.get_chip_usage(chip_type)
            ]
            return chip_type, count, data
        except Exception as e:
            print(e)
            return "no tpu detected", 0, []

    def run(self):
        try:
            step = 0
            while not self._stop_event.is_set():
                self.log_tpu_mlflow(step)
                step += 1
                # wait() doubles as the sleep and returns early once
                # stop() has been called.
                self._stop_event.wait(self.sampling_interval)
        except Exception as e:
            print(f"Error in background logger process: {e}")

    def stop(self):
        self._stop_event.set()
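
The sampling loop in run() can be tried on its own, without a TPU or a tracking server. Below is a minimal sketch of the same pattern, assuming a hypothetical helper sample_until_stopped and a fake logging function. It uses threading.Event as a stand-in, since it offers the same is_set(), wait() and set() methods as multiprocessing.Event.

```python
import threading


def sample_until_stopped(stop_event, interval_seconds, log_fn):
    """Call log_fn(step) until stop_event is set; return the number of samples."""
    step = 0
    while not stop_event.is_set():
        log_fn(step)
        step += 1
        # wait() returns after the timeout, or immediately once the event is set.
        stop_event.wait(interval_seconds)
    return step


# Hypothetical stand-in for log_tpu_mlflow: record the step and request a
# stop after three samples.
stop_event = threading.Event()
seen = []


def fake_log(step):
    seen.append(step)
    if len(seen) == 3:
        stop_event.set()


total = sample_until_stopped(stop_event, interval_seconds=0.01, log_fn=fake_log)
print(seen, total)  # [0, 1, 2] 3
```

Because the loop waits on the event instead of calling time.sleep(), a stop request takes effect right away rather than after a full sampling interval.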

To use this in your experiment you need something like this:

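with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    tpu_monitor = TPUMonitoring(run_id=run.info.run_id)
    tpu_monitor.start()
    ...
    tpu_monitor.stop()  # ask the background process to leave its loop
    tpu_monitor.join()  # wait for it to finish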
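
The monitor loops until stop() is called, so it helps to pair every start() with a stop() and a join(), even when training raises an exception. Here is a minimal sketch of a context manager that does this. The name running and the join_timeout parameter are made up for illustration, and FakeMonitor only stands in for TPUMonitoring so the example runs anywhere.

```python
from contextlib import contextmanager


@contextmanager
def running(monitor, join_timeout=5.0):
    """Start the monitor, yield it, and always stop and join it afterwards."""
    monitor.start()
    try:
        yield monitor
    finally:
        monitor.stop()
        monitor.join(join_timeout)


class FakeMonitor:
    """Stand-in exposing the same start/stop/join methods as TPUMonitoring."""

    def __init__(self):
        self.calls = []

    def start(self):
        self.calls.append("start")

    def stop(self):
        self.calls.append("stop")

    def join(self, timeout=None):
        self.calls.append("join")


fake = FakeMonitor()
try:
    with running(fake):
        raise RuntimeError("training failed")
except RuntimeError:
    pass
print(fake.calls)  # ['start', 'stop', 'join']

# With the real class this would look roughly like:
#   with mlflow.start_run() as run, running(TPUMonitoring(run_id=run.info.run_id)):
#       ...
```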

And that’s it!

If you have logging enabled, you might get some info or warning messages from protobuf or tpu_info, but in my experience they are safe to ignore.

I wanted to also attach the start function to the MLFlow module with something like:

    mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start

So that it would look nicer and ‘built-in’, like:

with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    mlflow.enable_tpu_metrics_logging()

But I leave that as future work.
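
One catch with the assignment above is that it binds the start method of a single instance, built from whichever run exists at that moment, and a multiprocessing.Process can only be started once. A rough sketch of one way around this is to build the monitor when the function is called. Here make_enable_tpu_metrics_logging is a hypothetical helper, FakeMonitor stands in for TPUMonitoring, and looking up the run via mlflow.active_run() is an assumption about how the wiring might be done.

```python
def make_enable_tpu_metrics_logging(monitor_cls, get_run_id):
    """Return a function that creates and starts a fresh monitor per call."""

    def enable_tpu_metrics_logging(sampling_interval_seconds=1):
        # The run id is resolved at call time, not when the helper is attached.
        monitor = monitor_cls(
            run_id=get_run_id(),
            sampling_interval_seconds=sampling_interval_seconds,
        )
        monitor.start()
        return monitor  # returned so the caller can stop() it later

    return enable_tpu_metrics_logging


class FakeMonitor:
    """Stand-in with the same constructor arguments as TPUMonitoring."""

    def __init__(self, run_id, sampling_interval_seconds=1):
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self.started = False

    def start(self):
        self.started = True


enable = make_enable_tpu_metrics_logging(FakeMonitor, lambda: "run-123")
first = enable()
second = enable(sampling_interval_seconds=5)
print(first.run_id, first.started, first is second)  # run-123 True False

# With the real pieces, the wiring might look like:
#   mlflow.enable_tpu_metrics_logging = make_enable_tpu_metrics_logging(
#       TPUMonitoring, lambda: mlflow.active_run().info.run_id)
```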

 