何もしなくても全GPUを使って学習してくれるかと思ってたけどそうではなかった。 以下のようにしてGPU数を渡してやると全GPUを使ってくれる。
def get_available_gpus(): local_device_protos = device_lib.list_local_devices() return [x.name for x in local_device_protos if x.device_type == 'GPU'] n_gpu = len(get_available_gpus()) if n_gpu >= 2: model = multi_gpu_model(model, gpus=n_gpu)
参考: