TPU Monitoring in MLFlow Part 2
As promised, this is a follow-up post with a cleaner way to track Google’s TPU performance in MLFlow while training your models.
This will spawn a background process (not a thread; jaxlib didn’t like that) that periodically sends the relevant TPU information to the MLFlow Tracking server.
import multiprocessing
import mlflow
class TPUMonitoring(multiprocessing.Process):
    """Background process that periodically logs TPU metrics to an MLflow run.

    Every ``sampling_interval_seconds`` the process queries ``tpu_info`` for
    per-chip usage and logs each value as an MLflow metric under the key
    ``system/<chip_type>/<device_id>/<metric>``.

    The process is created as a daemon, so a monitor that is never stopped
    explicitly does not keep the interpreter alive at exit. Call ``stop()``
    (and optionally ``join()``) for an orderly shutdown.

    Args:
        run_id: ID of the MLflow run the metrics are logged to.
        sampling_interval_seconds: Time to wait between two samples.
    """

    def __init__(self, run_id, sampling_interval_seconds=1):
        # daemon=True: multiprocessing joins non-daemonic children at
        # interpreter exit, which would block forever on this endless
        # sampling loop whenever the caller did not keep a handle to stop().
        super().__init__(daemon=True)
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()
        # Created lazily (see ``client``) so the child process builds its own
        # client instead of receiving one copied/pickled from the parent.
        self._client = None

    @property
    def client(self):
        """Return the MLflow client, creating it on first access."""
        if self._client is None:
            self._client = mlflow.client.MlflowClient()
        return self._client

    @client.setter
    def client(self, value):
        """Allow callers to supply their own client instance."""
        self._client = value

    def log_tpu_mlflow(self, step):
        """Sample the TPU metrics once and log them to MLflow at ``step``.

        Nothing is logged when ``tpu_system_metrics`` returns no chip data.
        """
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            device_id = tpu["device_id"]
            for metric, val in tpu.items():
                # device_id is part of the metric key, not a metric itself.
                if metric == "device_id":
                    continue
                self.client.log_metric(
                    key=f"system/{chip_type}/{device_id}/{metric}",
                    value=float(val),
                    step=step,
                    run_id=self.run_id,
                )

    def tpu_system_metrics(self):
        """Return ``(chip_type, chip_count, per_chip_data)`` for local TPUs.

        ``per_chip_data`` is a list of dicts with the keys ``device_id``,
        ``memory_usage``, ``total_memory`` and ``duty_cycle_pct``.

        If no chip is detected, or anything fails (including ``tpu_info``
        not being installed), the error is printed where applicable and
        ``("no tpu detected", 0, [])`` is returned.
        """
        try:
            # Imported here so a missing tpu_info package is reported as
            # "no tpu detected" instead of breaking the import of this module.
            import tpu_info.device
            import tpu_info.metrics

            chip_type, count = tpu_info.device.get_local_chips()
            if chip_type is None:
                return "no tpu detected", 0, []

            usage = tpu_info.metrics.get_chip_usage(chip_type)
            data = [
                {
                    "device_id": d.device_id,
                    "memory_usage": d.memory_usage,
                    "total_memory": d.total_memory,
                    "duty_cycle_pct": d.duty_cycle_pct,
                }
                for d in usage
            ]
            return chip_type, count, data
        except Exception as e:
            print(e)
            return "no tpu detected", 0, []

    def run(self):
        """Process entry point: log one sample per interval until stopped."""
        try:
            step = 0
            while not self._stop_event.is_set():
                self.log_tpu_mlflow(step)
                step += 1
                # Event.wait doubles as an interruptible sleep: it returns
                # early as soon as stop() sets the event.
                self._stop_event.wait(self.sampling_interval)
        except Exception as e:
            print(f"Error in background logger process: {e}")

    def stop(self):
        """Ask the sampling loop to finish after the current iteration."""
        self._stop_event.set()
To use this in your experiment you need something like this:
with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    # Keep a handle to the monitor so it can be shut down once training ends.
    tpu_monitor = TPUMonitoring(run_id=run.info.run_id)
    tpu_monitor.start()
    try:
        ...
    finally:
        # Signal the sampling loop to finish and wait for the process to exit.
        tpu_monitor.stop()
        tpu_monitor.join()
And that’s it!
If you have logging enabled, you might get some info/warning messages from protobuf and tpu_info, but in my experience they are safe to ignore.
I wanted to also attach the start function to the MLFlow module with something like:
# NOTE: this binds the `start` method of one specific TPUMonitoring instance,
# so `run` must already exist when this assignment is executed.
mlflow.enable_tpu_metrics_logging = TPUMonitoring(run_id=run.info.run_id).start
So that it would look nicer and ‘built-in’, like:
with mlflow.start_run() as run:
    mlflow.enable_system_metrics_logging()
    # Hypothetical helper attached to the mlflow module (see the assignment above).
    mlflow.enable_tpu_metrics_logging()
But I leave that as future work.