TPU Monitoring in MLflow Part 2
As promised, this is a follow-up post with a cleaner way to track the performance of Google's TPUs in MLflow while training your models.
The snippet below spawns a background process (not a thread; jaxlib didn't like that) that periodically sends the relevant TPU metrics to the MLflow Tracking server.
import multiprocessing

import mlflow


class TPUMonitoring(multiprocessing.Process):
    def __init__(self, run_id, sampling_interval_seconds=1):
        super().__init__()
        self.run_id = run_id
        self.sampling_interval = sampling_interval_seconds
        self._stop_event = multiprocessing.Event()
        self.client = mlflow.client.MlflowClient()

    def log_tpu_mlflow(self, step):
        # tpu_system_metrics() is the metrics-collection helper from part 1;
        # it returns the chip type, chip count, and one dict per device.
        chip_type, _, data = self.tpu_system_metrics()
        for tpu in data:
            for metric, val in tpu.items():
                if metric != "device_id":
                    # Tag each metric with its device id (naming is illustrative).
                    self.client.log_metric(
                        self.run_id,
                        f"{metric}_device_{tpu['device_id']}",
                        val,
                        step=step,
                    )
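The excerpt above covers collection and logging; the class still needs the part that makes it periodic. Something along these lines works, though the original loop isn't shown here, so treat the step counter and the use of the stop event as an interruptible sleep as my assumptions rather than the original code:

# Continuing the TPUMonitoring class body:

    def run(self):
        # Executes in the child process once start() is called.
        step = 0
        while not self._stop_event.is_set():
            self.log_tpu_mlflow(step)
            step += 1
            # Event.wait doubles as a sleep that stop() can interrupt,
            # so shutdown doesn't have to wait out a full interval.
            self._stop_event.wait(self.sampling_interval)

    def stop(self):
        # Called from the parent; the Event is shared with the child
        # process, so its loop sees the flag and exits cleanly.
        self._stop_event.set()
        self.join()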
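With that in place, wiring the monitor into a training script could look something like this; train() is a placeholder for whatever your actual training loop is:

import mlflow

with mlflow.start_run() as run:
    monitor = TPUMonitoring(run.info.run_id, sampling_interval_seconds=5)
    monitor.start()
    try:
        train()  # placeholder for your training loop
    finally:
        monitor.stop()  # tear the sampler down even if training fails

The try/finally matters: if training raises, the child process would otherwise keep logging against a run that has already finished.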