Testing PyTorch XLA with Google Colab TPUs
In case you are not aware, the PyTorch XLA project is an effort to run PyTorch on TPUs (Tensor Processing Units), an architecture that can offer even higher performance when training deep learning models than GPUs.
In my previous post I showed you how to use serveo.net to SSH into your Colab-provided container.
In this short post I will show you how to use the same principle to connect to the free TPUs that Colab provides. Beware that PyTorch XLA is still in its early days, so you may run into bugs.
Open a Colab notebook and set it up (Runtime -> Runtime Type -> Select Python 3 and TPU)
Create a code cell and paste the following:
import os

# Colab exposes the TPU address of the runtime in this environment variable
tpu_addr = os.environ['COLAB_TPU_ADDR']
print('Local TPU address', tpu_addr)

import random, string, getpass

# random root password and a random serveo.net alias for this session
password = ''.join(random.choice(string.ascii_letters + string.digits) for i in range(20))
alias = ''.join(random.choice(string.ascii_letters + string.digits) for i in range(8))

# set the root password and install an SSH server inside the Colab container
! echo root:$password | chpasswd
! apt-get update > /dev/null
! apt-get install -qq -o=Dpkg::Use-Pty=0 openssh-server sshpass iputils-ping vim tmux htop > /dev/null
! mkdir -p /var/run/sshd
! echo -e "PermitRootLogin yes \n PasswordAuthentication yes \n" >> /etc/ssh/sshd_config
! echo -e "export LD_LIBRARY_PATH=/usr/lib64-nvidia \n" >> /root/.bashrc
get_ipython().system_raw('/usr/sbin/sshd -D &')

print('Public TPU address', alias + '.serveo.net:8470')

# reverse-forward the TPU port through serveo.net so it is reachable from outside Colab
! ssh -o "StrictHostKeyChecking no" -R $alias:8470:$tpu_addr serveo.net
Run the cell (it will keep running to hold the tunnel open) and you should see output similar to:
Local TPU address 10.16.149.66:8470
Public TPU address JtkUVVKo.serveo.net:8470
Forwarding TCP connections from serveo.net:8470
Press g to start a GUI session and ctrl-c to quit.
Congrats, your TPUs are now reachable from outside Colab. Both addresses are printed, but it is the public one you will need in the next step.
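Before moving on, you can quickly verify from your local machine that the tunnel is actually up. A minimal sketch in Python, assuming the example address JtkUVVKo.serveo.net from the output above (substitute your own):

import socket

# open and immediately close a TCP connection to the forwarded TPU port
socket.create_connection(('JtkUVVKo.serveo.net', 8470), timeout=10).close()
print('Tunnel is reachable')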
Now, on your local machine, we will pull the PyTorch XLA Docker image and run it. Sadly, since Google Colab is itself already containerized, we cannot run Docker on it. Be sure to check the Docker installation instructions for your OS; on Fedora it is a simple dnf install docker and a reboot. After installation, run:
docker pull gcr.io/tpu-pytorch/xla:nightly
to get the image (~6 GB), then

docker run -it --shm-size 4G gcr.io/tpu-pytorch/xla:nightly

to get to the shell. Adjust shm-size based on your RAM.

Now, inside the container, we need to tell it where the TPUs are:
(pytorch) root@CONTAINERID:/$ export XRT_TPU_CONFIG="tpu_worker;0;[Your Public TPU Address e.g. JtkUVVKo.serveo.net:8470]"
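Before running the full test, you can do a quick sanity check that the container actually talks to the TPU over the tunnel. A minimal sketch, assuming the torch_xla build in the image exposes the usual xla_model helpers:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # default XLA device, a TPU core here
x = torch.randn(3, 3, device=device)
print(device, torch.matmul(x, x))   # a tiny matmul forces a round trip to the TPU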
Then run the bundled MNIST test:
(pytorch) root@CONTAINERID:/$ python pytorch/xla/test/test_train_mnist.py
The output should look like this:
(pytorch) root@8f8bf26dc0c6:/# python pytorch/xla/test/test_train_mnist.py
2019-08-26 07:52:49.969079: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) CPU:0 -> /job:tpu_worker/replica:0/task:0/device:XLA_CPU:0
2019-08-26 07:52:49.969165: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:0 -> /job:tpu_worker/replica:0/task:0/device:TPU:0
2019-08-26 07:52:49.969208: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:1 -> /job:tpu_worker/replica:0/task:0/device:TPU:1
2019-08-26 07:52:49.969242: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:2 -> /job:tpu_worker/replica:0/task:0/device:TPU:2
2019-08-26 07:52:49.969266: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:3 -> /job:tpu_worker/replica:0/task:0/device:TPU:3
2019-08-26 07:52:49.969290: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:4 -> /job:tpu_worker/replica:0/task:0/device:TPU:4
2019-08-26 07:52:49.969310: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:5 -> /job:tpu_worker/replica:0/task:0/device:TPU:5
2019-08-26 07:52:49.969328: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:6 -> /job:tpu_worker/replica:0/task:0/device:TPU:6
2019-08-26 07:52:49.969346: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:196] XRT device (LOCAL) TPU:7 -> /job:tpu_worker/replica:0/task:0/device:TPU:7
2019-08-26 07:52:49.969385: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:200] Worker grpc://SKuPrJqO.serveo.net:8470 for /job:tpu_worker/replica:0/task:0
2019-08-26 07:52:49.969413: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:204] XRT default device: TPU:0
2019-08-26 07:52:49.969813: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:1086] Configuring TPU for worker tpu_worker:0 at grpc://SKuPrJqO.serveo.net:8470
2019-08-26 07:52:58.647554: I tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:1102] TPU topology: mesh_shape: 2
mesh_shape: 2
mesh_shape: 2
num_tasks: 1
num_tpu_devices_per_task: 8
device_coordinates: 0
device_coordinates: 0
device_coordinates: 0
device_coordinates: 0
device_coordinates: 0
device_coordinates: 1
device_coordinates: 0
device_coordinates: 1
device_coordinates: 0
device_coordinates: 0
device_coordinates: 1
device_coordinates: 1
device_coordinates: 1
device_coordinates: 0
device_coordinates: 0
device_coordinates: 1
device_coordinates: 0
device_coordinates: 1
device_coordinates: 1
device_coordinates: 1
device_coordinates: 0
device_coordinates: 1
device_coordinates: 1
device_coordinates: 1
[xla:1](0) Loss=2.42430 Rate=39.84
[xla:2](0) Loss=2.39250 Rate=32.47
[xla:3](0) Loss=2.37206 Rate=31.64
[xla:5](0) Loss=2.38739 Rate=31.63
[xla:4](0) Loss=2.41775 Rate=31.34
[xla:8](0) Loss=2.39084 Rate=29.97
[xla:7](0) Loss=2.36097 Rate=27.51
[xla:6](0) Loss=2.38354 Rate=26.88
[xla:5](20) Loss=0.50134 Rate=76.40
[xla:8](20) Loss=0.52967 Rate=76.35
[xla:6](20) Loss=0.43902 Rate=76.23
[xla:1](20) Loss=0.42783 Rate=78.70
[xla:3](20) Loss=0.46408 Rate=75.88
[xla:7](20) Loss=0.47769 Rate=76.00
[xla:2](20) Loss=0.51218 Rate=76.01
[xla:4](20) Loss=0.51220 Rate=75.82
[xla:3](40) Loss=0.22904 Rate=140.75
[xla:2](40) Loss=0.13915 Rate=140.54
[xla:8](40) Loss=0.26900 Rate=139.75
[xla:6](40) Loss=0.24324 Rate=140.33
[xla:4](40) Loss=0.25083 Rate=140.38
[xla:1](40) Loss=0.27989 Rate=142.32
[xla:5](40) Loss=0.22283 Rate=139.64
[xla:7](40) Loss=0.18603 Rate=140.28
[xla:3] Accuracy=95.83%
[xla:1] Accuracy=95.92%
[xla:8] Accuracy=96.09%
[xla:2] Accuracy=97.22%
[xla:7] Accuracy=96.53%
[xla:4] Accuracy=96.61%
[xla:5] Accuracy=96.09%
[xla:6] Accuracy=95.83%
...
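If you want to try your own code rather than the bundled MNIST test, a minimal single-device training loop looks roughly like this (a sketch only; the model and data are placeholders, and it assumes the same container and XRT_TPU_CONFIG as above):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                         # single TPU core
model = nn.Linear(10, 2).to(device)              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(10):
    data = torch.randn(32, 10, device=device)    # placeholder batch
    target = torch.randint(0, 2, (32,), device=device)
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)    # steps the optimizer and flushes the lazy XLA graph
    print(step, loss.item())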
That’s it, I hope it works for you too! As for performance, it is difficult to say much since we are running over an SSH tunnel. Still, this is a nice way to test whether your code is TPU compatible. :)