TPU Monitoring in MLFlow Part 2

Huge thanks to Google TPU Research Cloud for providing me with access to TPU chips!

As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.

The class below spawns a background process (not a thread; jaxlib didn’t like that) that periodically sends the relevant TPU metrics (memory usage, total memory and duty cycle for each chip) to the MLFlow Tracking server.

    import multiprocessing

    import mlflow


    class TPUMonitoring(multiprocessing.Process):
        """Background process that periodically logs TPU metrics to an MLFlow run."""

        def __init__(self, run_id, sampling_interval_seconds=1):
            # daemon=True so a monitor that is never stopped cannot block interpreter exit
            super().__init__(daemon=True)
            self.run_id = run_id
            self.sampling_interval = sampling_interval_seconds
            self._stop_event = multiprocessing.Event()
            # Remember the parent's tracking URI; the client itself is created in run(),
            # inside the child process, so it never has to be pickled.
            self.tracking_uri = mlflow.get_tracking_uri()
            self.client = None

        def log_tpu_mlflow(self, step):
            chip_type, _, data = self.tpu_system_metrics()
            for tpu in data:
                for metric, val in tpu.items():
                    if metric != "device_id":
                        self.client.log_metric(
                            key=f"system/{chip_type}/{tpu['device_id']}/{metric}",
                            value=float(val),
                            step=step,
                            run_id=self.run_id,
                        )

        def tpu_system_metrics(self):
            try:
                import tpu_info.device
                import tpu_info.metrics

                chip_type, count = tpu_info.device.get_local_chips()
                if chip_type is None:
                    return "no tpu detected", 0, []
                usage = tpu_info.metrics.get_chip_usage(chip_type)
                data = [
                    {
                        "device_id": d.device_id,
                        "memory_usage": d.memory_usage,
                        "total_memory": d.total_memory,
                        "duty_cycle_pct": d.duty_cycle_pct,
                    }
                    for d in usage
                ]
                return chip_type, count, data
            except Exception as e:
                print(e)
                return "no tpu detected", 0, []

        def run(self):
            # Executed in the child process after start()
            try:
                self.client = mlflow.client.MlflowClient(tracking_uri=self.tracking_uri)
                step = 0
                while not self._stop_event.is_set():
                    self.log_tpu_mlflow(step)
                    step += 1
                    # Interruptible sleep: returns early as soon as stop() sets the event
                    self._stop_event.wait(self.sampling_interval)
            except Exception as e:
                print(f"Error in background logger process: {e}")

        def stop(self):
            self._stop_event.set()
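To see which metric names end up in the run, here is a stdlib-only sketch of the key scheme that `log_tpu_mlflow` uses. `flatten_tpu_metrics` is a hypothetical helper, the readings are made up, and the literal `"v4"` stands in for the string form of whatever chip type `tpu_info` returns.

```python
# Hypothetical helper mirroring the key scheme used in log_tpu_mlflow:
# system/<chip_type>/<device_id>/<metric>
def flatten_tpu_metrics(chip_type, data):
    """Turn per-chip usage dicts into a flat {metric_key: float} mapping."""
    flat = {}
    for tpu in data:
        for metric, val in tpu.items():
            if metric == "device_id":
                continue  # the device id becomes part of the key, not a logged value
            flat[f"system/{chip_type}/{tpu['device_id']}/{metric}"] = float(val)
    return flat


# Made-up readings for two chips, in the same shape tpu_system_metrics() builds
sample = [
    {"device_id": 0, "memory_usage": 1024, "total_memory": 4096, "duty_cycle_pct": 87.5},
    {"device_id": 1, "memory_usage": 2048, "total_memory": 4096, "duty_cycle_pct": 90.0},
]

for key, value in sorted(flatten_tpu_metrics("v4", sample).items()):
    print(key, value)
```

Each chip therefore ends up with three series (memory_usage, total_memory and duty_cycle_pct), grouped under its device id.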

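To use this in your experiment, start the monitor inside a run and stop it once training is done:

    mlflow.enable_system_metrics_logging()  # must be called before start_run() to take effect

    with mlflow.start_run() as run:
        tpu_monitor = TPUMonitoring(run_id=run.info.run_id)
        tpu_monitor.start()
        try:
            ...  # your training code
        finally:
            # Signal the background process to exit, then wait for it
            tpu_monitor.stop()
            tpu_monitor.join()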
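If you want to check the start/stop mechanics without a TPU or a tracking server, the same pattern can be exercised with a stand-in. In this stdlib-only sketch, the hypothetical `FakeSampler` pushes its step counter into a `multiprocessing.Queue` where `TPUMonitoring` would log to MLFlow, and `collect_steps` plays the role of the training run. It assumes a Linux machine (as on TPU VMs) and requests the `fork` start method explicitly.

```python
import multiprocessing
import time

# Assumption: Linux (as on TPU VMs), so the "fork" start method is available.
ctx = multiprocessing.get_context("fork")


class FakeSampler(ctx.Process):
    """Hypothetical stand-in for TPUMonitoring: sends step numbers to a queue."""

    def __init__(self, queue, sampling_interval_seconds=0.05):
        super().__init__(daemon=True)
        self.queue = queue
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = ctx.Event()

    def run(self):
        step = 0
        while not self._stop_event.is_set():
            self.queue.put(step)  # TPUMonitoring calls log_tpu_mlflow(step) here
            step += 1
            self._stop_event.wait(self.sampling_interval)  # interruptible sleep
        self.queue.put(None)  # sentinel: nothing more is coming

    def stop(self):
        self._stop_event.set()


def collect_steps(run_for_seconds=0.3):
    queue = ctx.Queue()
    sampler = FakeSampler(queue)
    sampler.start()
    time.sleep(run_for_seconds)  # the "training" happens here
    sampler.stop()
    steps = []
    # Drain the queue up to the sentinel before join()
    while (item := queue.get(timeout=5)) is not None:
        steps.append(item)
    sampler.join(timeout=5)
    return steps, sampler.exitcode
```

The sentinel matters: as the multiprocessing programming guidelines warn, a process that has put items on a queue may not terminate until they are consumed, so the parent drains the queue before calling `join()`.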

And that’s it!

If you have logging enabled, you might see some info/warning messages from protobuf and tpu_info; in my experience they are safe to ignore.

I also wanted to attach the start function to the MLFlow module with something like:

    mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start

That way it would look nicer and more ‘built-in’:

    mlflow.enable_system_metrics_logging()
    with mlflow.start_run() as run:
        mlflow.enable_tpu_metrics_logging()

But I leave that as future work.

 