-
Notifications
You must be signed in to change notification settings - Fork 967
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
4 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
------------------------------------------------------------------------ | ||
--[[ GPU ]]-- | ||
-- Decorates a module such that its parameters are | ||
-- hosted on a specified GPU device. | ||
-- The operations are also executed on that device. | ||
-- Arguments input and gradOutput are converted to the specified device | ||
-- before being fed to the decorated module. | ||
-- Returned output is on the specified outdevice (defaults to device). | ||
-- Returned gradInput is allocated on the same device as the input. | ||
-- The unit test is located in cunn. | ||
------------------------------------------------------------------------ | ||
local GPU, parent = torch.class("nn.GPU", "nn.Container") | ||
|
||
function GPU:__init(module, device, outdevice) | ||
parent.__init(self) | ||
assert(torch.type(device) == 'number') | ||
self.device = device | ||
self.outdevice = outdevice or device | ||
|
||
assert(torch.isTypeOf(module, 'nn.Module')) | ||
self.modules[1] = module | ||
|
||
if module:type() == 'torch.CudaTensor' then | ||
self:cuda() | ||
end | ||
end | ||
|
||
function GPU.recursiveModuleDevice(obj, device) | ||
if type(obj) == 'table' and not torch.isTypeOf(obj, 'nn.GPU') and not obj.__noGPU__ then | ||
for k,v in pairs(obj) do | ||
obj[k] = GPU.recursiveModuleDevice(v, device) | ||
end | ||
elseif torch.type(obj):match('torch.Cuda.*Tensor') then | ||
if obj:getDevice() ~= device then | ||
obj = obj:clone() -- this will reallocate it to device | ||
local newdevice = obj:getDevice() | ||
-- when nElement() == 0 newdevice is 0 | ||
assert(newdevice == device or newdevice == 0) | ||
end | ||
end | ||
assert(obj ~= nil) | ||
return obj | ||
end | ||
|
||
-- set the device of the decorated module | ||
function GPU:setDevice(device) | ||
self.device = device or self.device | ||
|
||
assert(self.modules[1]) | ||
self.modules[1] = cutorch.withDevice(self.device, function() | ||
return self.recursiveModuleDevice(self.modules[1], self.device) | ||
end) | ||
return self | ||
end | ||
|
||
-- when proto is a device number, returns a dst that has device device for each element in src | ||
-- otherwise, if proto is a table/tensor, makes sure dst is a identical to src, yet on the same device as proto | ||
function GPU.recursiveSetDevice(dst, src, proto) | ||
local device, prototable | ||
if torch.isTensor(proto) then | ||
device = proto:getDevice() | ||
elseif torch.type(proto) == 'number' then | ||
device = proto | ||
elseif torch.type(proto) == 'table' then | ||
prototable = true | ||
else | ||
error"Expecting number, table or tensor for arg 3 (proto)" | ||
end | ||
if torch.type(src) == 'table' then | ||
dst = torch.type(dst) == 'table' and dst or {} | ||
for k,v in ipairs(src) do | ||
dst[k] = GPU.recursiveSetDevice(dst[k], v, prototable and proto[k] or device) | ||
end | ||
for k=#src+1,#dst do | ||
dst[k] = nil | ||
end | ||
elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then | ||
if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then | ||
dst = src.new() | ||
end | ||
cutorch.withDevice(device, function() dst:resizeAs(src):copy(src) end) | ||
else | ||
dst = src | ||
end | ||
return dst | ||
end | ||
|
||
function GPU:updateOutput(input) | ||
if self._type == 'torch.CudaTensor' then | ||
self._input = self.recursiveSetDevice(self._input, input, self.device) | ||
|
||
local output = cutorch.withDevice(self.device, function() | ||
return self.modules[1]:updateOutput(self._input) | ||
end) | ||
|
||
if self.device ~= self.outdevice then | ||
self.output = self.recursiveSetDevice(self.output, output, self.outdevice) | ||
else | ||
self.output = output | ||
end | ||
else | ||
self.output = self.modules[1]:updateOutput(input) | ||
end | ||
|
||
return self.output | ||
end | ||
|
||
function GPU:updateGradInput(input, gradOutput) | ||
if self._type == 'torch.CudaTensor' then | ||
self._gradOutput = self.recursiveSetDevice(self._gradOutput, gradOutput, self.device) | ||
|
||
local gradInput = cutorch.withDevice(self.device, function() | ||
return self.modules[1]:updateGradInput(self._input, self._gradOutput) | ||
end) | ||
|
||
self.gradInput = self.recursiveSetDevice(self.gradInput, gradInput, input) | ||
else | ||
self.gradInput = self.modules[1]:updateGradInput(input, gradOutput) | ||
end | ||
|
||
return self.gradInput | ||
end | ||
|
||
function GPU:accGradParameters(input, gradOutput, scale) | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() | ||
self.modules[1]:accGradParameters(self._input, self._gradOutput, scale) | ||
end) | ||
else | ||
self.modules[1]:accGradParameters(input, gradOutput, scale) | ||
end | ||
end | ||
|
||
function GPU:apply(callback) | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.apply(self, callback) end) | ||
else | ||
parent.apply(self, callback) | ||
end | ||
end | ||
|
||
function GPU:type(type, typecache) | ||
if type and type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.type(self, type, typecache) end) | ||
self:setDevice() | ||
else | ||
self.output = nil | ||
self.gradInput = nil | ||
self._input = nil | ||
self._gradOutput = nil | ||
parent.type(self, type, typecache) | ||
end | ||
return self | ||
end | ||
|
||
function GPU:clearState() | ||
nn.utils.clear(self, 'output', 'gradInput') | ||
self._input = nil | ||
self._gradOutput = nil | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.clearState(self) end) | ||
else | ||
parent.clearState(self) | ||
end | ||
end | ||
|
||
function GPU:zeroGradParameters() | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.zeroGradParameters(self) end) | ||
else | ||
parent.zeroGradParameters(self) | ||
end | ||
end | ||
|
||
function GPU:updateParameters(lr) | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.updateParameters(self, lr) end) | ||
else | ||
parent.updateParameters(self, lr) | ||
end | ||
end | ||
|
||
function GPU:training() | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.training(self) end) | ||
else | ||
parent.training(self) | ||
end | ||
end | ||
|
||
function GPU:evaluate() | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.evaluate(self) end) | ||
else | ||
parent.evaluate(self) | ||
end | ||
end | ||
|
||
function GPU:share(mlp, ...) | ||
local args = {...} | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.share(self, mlp, unpack(args)) end) | ||
else | ||
parent.share(self, mlp, unpack(args)) | ||
end | ||
return self | ||
end | ||
|
||
function GPU:reset(...) | ||
local args = {...} | ||
if self._type == 'torch.CudaTensor' then | ||
cutorch.withDevice(self.device, function() parent.reset(self, unpack(args)) end) | ||
else | ||
parent.reset(self, unpack(args)) | ||
end | ||
return self | ||
end | ||
|
||
function GPU:clone(...) | ||
local args = {...} | ||
if self._type == 'torch.CudaTensor' then | ||
return cutorch.withDevice(self.device, function() parent.clone(self, unpack(args)) end) | ||
else | ||
return parent.clone(self, unpack(args)) | ||
end | ||
end | ||
|
||
function GPU:write(file) | ||
-- Write all values in the object as a table. | ||
local object = {} | ||
for k, v in pairs(self) do | ||
object[k] = v | ||
end | ||
local header = {self._type, self.device} | ||
file:writeObject(header) | ||
file:writeObject(object) | ||
end | ||
|
||
function GPU:read(file) | ||
local header = file:readObject() | ||
local object | ||
if header[1] == 'torch.CudaTensor' then | ||
local device = header[2] | ||
if device > cutorch.getDeviceCount() then | ||
print"Warning : model was saved with more devices than available on current host." | ||
print"Attempting to load module onto device 1" | ||
device = 1 | ||
end | ||
object = cutorch.withDevice(device, function() return file:readObject() end) | ||
else | ||
object = file:readObject() | ||
end | ||
|
||
for k, v in pairs(object) do | ||
self[k] = v | ||
end | ||
end | ||
|
||
function GPU:__tostring__() | ||
if self.modules[1].__tostring__ then | ||
return torch.type(self) .. '(' .. self.device ..') @ ' .. self.modules[1]:__tostring__() | ||
else | ||
return torch.type(self) .. '(' .. self.device ..') @ ' .. torch.type(self.modules[1]) | ||
end | ||
end | ||
|
||
function GPU:accUpdateGradParameters(input, gradOutput, lr) | ||
error("Not Implemented for "..torch.type(self)) | ||
end | ||
|
||
function GPU:sharedAccUpdateGradParameters(input, gradOutput, lr) | ||
error("Not Implemented for "..torch.type(self)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters