nn.GPU #835

Merged · 3 commits · Jul 2, 2016
273 changes: 273 additions & 0 deletions GPU.lua
@@ -0,0 +1,273 @@
------------------------------------------------------------------------
--[[ GPU ]]--
-- Decorates a module such that its parameters are
-- hosted on a specified GPU device.
-- The operations are also executed on that device.
-- Arguments input and gradOutput are converted to the specified device
-- before being fed to the decorated module.
-- Returned output is on the specified outdevice (defaults to device).
-- Returned gradInput is allocated on the same device as the input.
-- The unit test is located in cunn.
------------------------------------------------------------------------
local GPU, parent = torch.class("nn.GPU", "nn.Container")

function GPU:__init(module, device, outdevice)
   parent.__init(self)
   assert(torch.type(device) == 'number')
   self.device = device
   self.outdevice = outdevice or device

   assert(torch.isTypeOf(module, 'nn.Module'))
   self.modules[1] = module

   if module:type() == 'torch.CudaTensor' then
      self:cuda()
> **Review comment (Member):** this `:cuda()` is not executing in the context of `device`; needs a fix.
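One possible fix, sketched here as an assumption rather than the code merged in this PR, is to run the conversion with `device` made current:

```lua
-- hypothetical fix sketch: make `device` current while converting the module
if module:type() == 'torch.CudaTensor' then
   cutorch.withDevice(device, function() self:cuda() end)
end
```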

   end
end

function GPU.recursiveModuleDevice(obj, device)
   if type(obj) == 'table' and not torch.isTypeOf(obj, 'nn.GPU') and not obj.__noGPU__ then
      for k,v in pairs(obj) do
         obj[k] = GPU.recursiveModuleDevice(v, device)
      end
   elseif torch.type(obj):match('torch.Cuda.*Tensor') then
      if obj:getDevice() ~= device then
         obj = obj:clone() -- this will reallocate it to device
         local newdevice = obj:getDevice()
         -- when nElement() == 0, newdevice is 0
         assert(newdevice == device or newdevice == 0)
      end
   end
   assert(obj ~= nil)
   return obj
end

-- set the device of the decorated module
function GPU:setDevice(device)
   self.device = device or self.device

   assert(self.modules[1])
   self.modules[1] = cutorch.withDevice(self.device, function()
      return self.recursiveModuleDevice(self.modules[1], self.device)
   end)
   return self
end

-- when proto is a device number, returns a dst whose every element is on that device;
-- otherwise, when proto is a table/tensor, makes dst identical to src, yet on the same device as proto
function GPU.recursiveSetDevice(dst, src, proto)
   local device, prototable
   if torch.isTensor(proto) then
      device = proto:getDevice()
   elseif torch.type(proto) == 'number' then
      device = proto
> **Review comment (Member):** else... error
   elseif torch.type(proto) == 'table' then
      prototable = true
   else
      error"Expecting number, table or tensor for arg 3 (proto)"
   end
   if torch.type(src) == 'table' then
      dst = torch.type(dst) == 'table' and dst or {}
      for k,v in ipairs(src) do
         dst[k] = GPU.recursiveSetDevice(dst[k], v, prototable and proto[k] or device)
      end
      for k=#src+1,#dst do
         dst[k] = nil
      end
   elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then
> **Review comment (Member):** this whole function can just call `GPU.recursiveSetDevice` above after figuring out the device id in lines 79-84. Lots of duplication.
>
> **Author:** You are right. They used to be more different. Refactored in latest commit.

      if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then
         dst = src.new()
      end
      cutorch.withDevice(device, function() dst:resizeAs(src):copy(src) end)
   else
      dst = src
   end
   return dst
end

function GPU:updateOutput(input)
   if self._type == 'torch.CudaTensor' then
      self._input = self.recursiveSetDevice(self._input, input, self.device)

      local output = cutorch.withDevice(self.device, function()
         return self.modules[1]:updateOutput(self._input)
      end)

      if self.device ~= self.outdevice then
> **Review comment (Member):** the code never reaches this whole block. it has a return statement in line 108 (two lines above)
>
> **Review comment (Member):** oh nevermind about this comment. my bad.
         self.output = self.recursiveSetDevice(self.output, output, self.outdevice)
      else
         self.output = output
      end
   else
      self.output = self.modules[1]:updateOutput(input)
   end

   return self.output
end

function GPU:updateGradInput(input, gradOutput)
   if self._type == 'torch.CudaTensor' then
      self._gradOutput = self.recursiveSetDevice(self._gradOutput, gradOutput, self.device)

      local gradInput = cutorch.withDevice(self.device, function()
         return self.modules[1]:updateGradInput(self._input, self._gradOutput)
      end)

      self.gradInput = self.recursiveSetDevice(self.gradInput, gradInput, input)
   else
      self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
   end

   return self.gradInput
end

function GPU:accGradParameters(input, gradOutput, scale)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function()
         self.modules[1]:accGradParameters(self._input, self._gradOutput, scale)
      end)
   else
      self.modules[1]:accGradParameters(input, gradOutput, scale)
   end
end

function GPU:apply(callback)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.apply(self, callback) end)
   else
      parent.apply(self, callback)
   end
end

function GPU:type(type, typecache)
   if type and type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.type(self, type, typecache) end)
      self:setDevice()
   else
      self.output = nil
      self.gradInput = nil
      self._input = nil
      self._gradOutput = nil
      parent.type(self, type, typecache)
   end
   return self
end

function GPU:clearState()
   nn.utils.clear(self, 'output', 'gradInput')
   self._input = nil
   self._gradOutput = nil
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.clearState(self) end)
   else
      parent.clearState(self)
   end
end

function GPU:zeroGradParameters()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.zeroGradParameters(self) end)
   else
      parent.zeroGradParameters(self)
   end
end

function GPU:updateParameters(lr)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.updateParameters(self, lr) end)
   else
      parent.updateParameters(self, lr)
   end
end

function GPU:training()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.training(self) end)
   else
      parent.training(self)
   end
end

function GPU:evaluate()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.evaluate(self) end)
   else
      parent.evaluate(self)
   end
end

function GPU:share(mlp, ...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.share(self, mlp, unpack(args)) end)
   else
      parent.share(self, mlp, unpack(args))
   end
   return self
end

function GPU:reset(...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.reset(self, unpack(args)) end)
   else
      parent.reset(self, unpack(args))
   end
   return self
end

function GPU:clone(...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      return cutorch.withDevice(self.device, function() parent.clone(self, unpack(args)) end)
   else
      return parent.clone(self, unpack(args))
   end
end

function GPU:write(file)
   -- Write all values in the object as a table.
   local object = {}
   for k, v in pairs(self) do
      object[k] = v
   end
   local header = {self._type, self.device}
   file:writeObject(header)
   file:writeObject(object)
end

function GPU:read(file)
   local header = file:readObject()
   local object
   if header[1] == 'torch.CudaTensor' then
> **Review comment (Member):** it would be really useful / worth it here to check if total number of devices <= header[2]. If there are fewer devices than header[2], then print a WARNING and fall back to device[0]. That way, the model can be deserialized on a 1-GPU machine, though it was created on a 4-GPU machine, for example.
>
> **Author:** Done in latest commit.

      local device = header[2]
      if device > cutorch.getDeviceCount() then
         print"Warning : model was saved with more devices than available on current host."
         print"Attempting to load module onto device 1"
         device = 1
      end
      object = cutorch.withDevice(device, function() return file:readObject() end)
   else
      object = file:readObject()
   end

   for k, v in pairs(object) do
      self[k] = v
   end
end

function GPU:__tostring__()
   if self.modules[1].__tostring__ then
      return torch.type(self) .. '(' .. self.device ..') @ ' .. self.modules[1]:__tostring__()
   else
      return torch.type(self) .. '(' .. self.device ..') @ ' .. torch.type(self.modules[1])
   end
end

function GPU:accUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented for "..torch.type(self))
end

function GPU:sharedAccUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented for "..torch.type(self))
end
48 changes: 48 additions & 0 deletions doc/simple.md
@@ -51,6 +51,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [Padding](#nn.Padding) : adds padding to a dimension ;
* [L1Penalty](#nn.L1Penalty) : adds an L1 penalty to an input (for sparsity) ;
* [GradientReversal](#nn.GradientReversal) : reverses the gradient (to maximize an objective function) ;
* [GPU](#nn.GPU) : decorates a module so that it can be executed on a specific GPU device.

<a name="nn.Linear"></a>
## Linear ##
@@ -1357,3 +1358,50 @@ One can also call:
module:setLambda(lambda)
```
to set the hyper-parameter `lambda` dynamically during training.

<a name="nn.GPU"></a>
## GPU ##

```lua
gpu = nn.GPU(module, device, [outdevice])
require 'cunn'
gpu:cuda()
```

Decorates an encapsulated `module` so that it can be executed on a specific GPU `device`.
The decorated module's `parameters` are thus hosted on the specified GPU `device`.
All operations on the `gpu` module are executed on that device.
Calls to `forward`/`backward` will transfer arguments `input` and `gradOutput` to the specified `device`,
which are then fed as arguments to the decorated `module`.
Returned `output` is located on the specified `outdevice` (defaults to `device`).
Returned `gradInput` is allocated on the same device as the `input`.

When serialized/deserialized, the `gpu` module will be run on the same `device` that it was serialized with.
To prevent this from happening, the module can be converted to float/double before serialization:

```lua
gpu:float()
gpustr = torch.serialize(gpu)
```
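
To restore GPU execution after loading, the serialized string can be deserialized and converted back (a minimal sketch, assuming `cunn` is available on the loading host):

```lua
gpu = torch.deserialize(gpustr)
require 'cunn'
gpu:cuda() -- moves the wrapped module back onto gpu.device
```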

The module is located in the __nn__ package instead of __cunn__ as this allows
it to be used in CPU-only environments, which are common for production models.

The module supports nested table `input` and `gradOutput` tensors originating from multiple devices.
Each nested tensor in the returned `gradInput` will be transferred to the device of its corresponding tensor in the `input`.
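
For instance, a minimal sketch of nested-table usage (the devices, shapes and `ParallelTable` wrapping are assumptions for illustration, not code from this PR):

```lua
-- a GPU-decorated ParallelTable fed a table whose tensors live on different devices
prl = nn.GPU(nn.ParallelTable()
   :add(nn.Linear(100,100))
   :add(nn.Linear(100,100)), 2)
prl:cuda()

input = {
   cutorch.withDevice(1, function() return torch.CudaTensor(8,100):uniform() end),
   cutorch.withDevice(2, function() return torch.CudaTensor(8,100):uniform() end),
}
output = prl:forward(input) -- both table entries are moved to device 2 first
gradInput = prl:backward(input, output)
-- gradInput[1] comes back on device 1, gradInput[2] on device 2
```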

The intended use-case is not for model-parallelism where the models are executed in parallel on multiple devices, but
for sequential models where a single GPU doesn't have enough memory.

Example using 4 GPUs:

```lua
mlp = nn.Sequential()
:add(nn.GPU(nn.Linear(10000,10000), 1))
:add(nn.GPU(nn.Linear(10000,10000), 2))
:add(nn.GPU(nn.Linear(10000,10000), 3))
:add(nn.GPU(nn.Linear(10000,10000), 4, cutorch.getDevice()))
```

> **Review comment (Contributor):** I am wondering if this line can run in CPU-only environments
>
> **Author (@nicholas-leonard, May 27, 2016):** the `cutorch.getDevice` will not. But it isn't mandatory to use `cutorch.getDevice()`, this is just an example.

Note how the last `GPU` instance will return an `output` tensor on the same device as the current device (`cutorch.getDevice`).
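
Since `cutorch.getDevice()` in the example above won't run on a CPU-only host, a hedged variant (an assumption on our part, not code from this PR) is to hard-code the out-device instead:

```lua
-- sketch: an explicit out-device avoids calling cutorch at construction time
mlp = nn.Sequential()
   :add(nn.GPU(nn.Linear(10000,10000), 1))
   :add(nn.GPU(nn.Linear(10000,10000), 2))
   :add(nn.GPU(nn.Linear(10000,10000), 3))
   :add(nn.GPU(nn.Linear(10000,10000), 4, 1)) -- output returned on device 1
```

And a minimal training-step sketch for the 4-GPU model (batch size and learning rate are assumed for illustration):

```lua
require 'cunn'
mlp:cuda() -- each nn.GPU converts its wrapped module on its own device

input = torch.CudaTensor(32,10000):uniform() -- allocated on the current device
output = mlp:forward(input)                  -- returned on the out-device
gradOutput = output:clone():fill(1)          -- placeholder gradient
mlp:zeroGradParameters()
gradInput = mlp:backward(input, gradOutput)  -- returned on input's device
mlp:updateParameters(0.1)
```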

2 changes: 2 additions & 0 deletions init.lua
@@ -124,6 +124,8 @@ require('nn.VolumetricMaxUnpooling')
require('nn.VolumetricAveragePooling')
require('nn.VolumetricBatchNormalization')

require('nn.GPU')

require('nn.ParallelTable')
require('nn.Identity')
require('nn.ConcatTable')
5 changes: 5 additions & 0 deletions test.lua
@@ -6228,6 +6228,11 @@ function nntest.ErrorHandling()
)
end

function nntest.GPU()
-- this is a placeholder to let you know that the nn.GPU unit test
-- is located in cunn package.
end

mytester:add(nntest)

jac = nn.Jacobian