-
Notifications
You must be signed in to change notification settings - Fork 967
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e764241
commit db7b972
Showing
4 changed files
with
329 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
------------------------------------------------------------------------
--[[ GPU ]]--
-- Decorates a module such that its parameters are
-- hosted on a specified GPU device.
-- The operations are also executed on that device.
-- Arguments input and gradOutput are converted to the specified device
-- before being fed to the decorated module.
-- Returned output is on the specified outdevice (defaults to device).
-- Returned gradInput is allocated on the same device as the input.
-- The unit test is located in cunn.
------------------------------------------------------------------------
-- GPU extends nn.Container; `parent` is used below to delegate to it.
local GPU, parent = torch.class("nn.GPU", "nn.Container")
|
||
-- Construct a GPU decorator around `module`.
-- device:    id of the GPU hosting the module's parameters and computation
-- outdevice: id of the GPU hosting the returned output (defaults to device)
function GPU:__init(module, device, outdevice)
   -- validate arguments up front
   assert(torch.type(device) == 'number')
   assert(torch.isTypeOf(module, 'nn.Module'))

   parent.__init(self)

   self.device = device
   self.outdevice = outdevice or device
   self.modules[1] = module

   -- if the wrapped module is already CUDA-typed, switch this container too
   if module:type() == 'torch.CudaTensor' then
      self:cuda()
   end
end
|
||
-- Recursively relocate every CUDA tensor reachable from obj onto `device`.
-- Nested nn.GPU containers are skipped so they keep their own device.
-- NOTE(review): callers are expected to invoke this under
-- cutorch.withDevice(device, ...) so that clone() allocates on `device`.
function GPU.recursiveModuleDevice(obj, device)
   if type(obj) == 'table' and not torch.isTypeOf(obj, 'nn.GPU') then
      -- plain table: recurse into each entry in place
      for key, child in pairs(obj) do
         obj[key] = GPU.recursiveModuleDevice(child, device)
      end
   elseif torch.type(obj):match('torch.Cuda.*Tensor') and obj:getDevice() ~= device then
      obj = obj:clone() -- reallocates onto the current device
      local landed = obj:getDevice()
      -- an empty tensor (nElement() == 0) reports device 0
      assert(landed == device or landed == 0)
   end
   assert(obj ~= nil)
   return obj
end
|
||
-- Move the decorated module's parameters onto self.device.
-- An optional `device` argument replaces the stored device id first.
function GPU:setDevice(device)
   if device then
      self.device = device
   end

   assert(self.modules[1])
   local function relocate()
      return self.recursiveModuleDevice(self.modules[1], self.device)
   end
   self.modules[1] = cutorch.withDevice(self.device, relocate)
   return self
end
|
||
-- Returns a dst that mirrors src with every CUDA tensor placed on `device`.
-- Reuses dst's storage when it is already a CUDA tensor on the right device.
-- NOTE(review): callers run this under cutorch.withDevice(device, ...), so
-- src.new() below allocates on `device`.
function GPU.recursiveSetDevice(dst, src, device)
   if torch.type(src) == 'table' then
      -- mirror the table structure, recursing per element
      dst = torch.type(dst) == 'table' and dst or {}
      for k,v in ipairs(src) do
         dst[k] = GPU.recursiveSetDevice(dst[k], v, device)
      end
      -- drop stale trailing entries left over from a previous, longer src
      for k=#src+1,#dst do
         dst[k] = nil
      end
   elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then
      -- src is on the wrong device (device 0 means an empty tensor: no move
      -- needed); copy it into a buffer located on `device`
      if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then
         dst = src.new()
      end
      dst:resizeAs(src):copy(src)
   else
      -- non-tensors and tensors already on `device` pass through by reference
      dst = src
   end
   return dst
end
|
||
-- Makes dst identical to src, but located on the same device as proto.
-- proto may be a tensor (its device is used), a device number, or a table
-- mirroring src's structure (matched element-wise during recursion).
function GPU.recursiveSetDeviceAs(dst, src, proto)
   local device
   if torch.isTensor(proto) then
      device = proto:getDevice()
   elseif torch.type(proto) == 'number' then
      device = proto
   end
   if torch.type(src) == 'table' then
      -- mirror the table structure, matching each element against proto[k]
      dst = torch.type(dst) == 'table' and dst or {}
      for k,v in ipairs(src) do
         dst[k] = GPU.recursiveSetDeviceAs(dst[k], v, proto[k])
      end
      -- drop stale trailing entries left over from a previous, longer src
      for k=#src+1,#dst do
         dst[k] = nil
      end
   elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then
      -- copy src into a buffer resized (and thus allocated) on proto's device
      if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then
         dst = src.new()
      end
      cutorch.withDevice(device, function() dst:resizeAs(src):copy(src) end)
   else
      -- non-tensors and tensors already on the target device pass through
      dst = src
   end
   return dst
end
|
||
-- Forward pass: copy input onto self.device, run the decorated module there,
-- then place the result on self.outdevice when it differs from self.device.
function GPU:updateOutput(input)
   if self._type == 'torch.CudaTensor' then
      local output = cutorch.withDevice(self.device, function()
         -- self._input is a persistent buffer holding input on self.device;
         -- it is also reused by updateGradInput/accGradParameters
         self._input = self.recursiveSetDevice(self._input, input, self.device)
         return self.modules[1]:updateOutput(self._input)
      end)

      if self.device ~= self.outdevice then
         -- transfer output to the requested output device
         self.output = cutorch.withDevice(self.outdevice, function()
            return self.recursiveSetDevice(self.output, output, self.outdevice)
         end)
      else
         self.output = output
      end
   else
      -- non-CUDA mode: plain pass-through to the decorated module
      self.output = self.modules[1]:updateOutput(input)
   end

   return self.output
end
|
||
-- Backward pass: copy gradOutput onto self.device, compute gradInput there,
-- then allocate gradInput on the same device(s) as the original input.
function GPU:updateGradInput(input, gradOutput)
   if self._type == 'torch.CudaTensor' then
      local gradInput = cutorch.withDevice(self.device, function()
         -- self._gradOutput buffers gradOutput on self.device;
         -- self._input was cached by updateOutput
         self._gradOutput = self.recursiveSetDevice(self._gradOutput, gradOutput, self.device)
         return self.modules[1]:updateGradInput(self._input, self._gradOutput)
      end)

      -- returned gradInput must live on the same device(s) as input
      self.gradInput = self.recursiveSetDeviceAs(self.gradInput, gradInput, input)
   else
      self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
   end

   return self.gradInput
end
|
||
-- Accumulate parameter gradients on self.device, reusing the input and
-- gradOutput buffers cached by updateOutput/updateGradInput.
function GPU:accGradParameters(input, gradOutput, scale)
   if self._type ~= 'torch.CudaTensor' then
      self.modules[1]:accGradParameters(input, gradOutput, scale)
      return
   end
   cutorch.withDevice(self.device, function()
      self.modules[1]:accGradParameters(self._input, self._gradOutput, scale)
   end)
end
|
||
-- Apply `callback` to this container and its children (on self.device when
-- in CUDA mode, so any tensor work in the callback lands there).
function GPU:apply(callback)
   local function run()
      parent.apply(self, callback)
   end
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, run)
   else
      run()
   end
end
|
||
-- Convert the container (and decorated module) to `type`.
-- With no argument this is a pure query returning the current type string,
-- matching nn.Module:type(). CUDA conversion runs on self.device; any other
-- conversion first drops the cached buffers, which are device-specific.
function GPU:type(type, typecache)
   if not type then
      -- BUGFIX: previously a no-arg query fell into the else branch, which
      -- destructively cleared the cached buffers and returned self instead
      -- of the current type (nn.Module:type() returns self._type).
      return self._type
   end
   if type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.type(self, type, typecache) end)
      -- re-home any tensors that landed on the wrong device
      self:setDevice()
   else
      -- leaving CUDA: the cached device buffers are no longer valid
      self.output = nil
      self.gradInput = nil
      self._input = nil
      self._gradOutput = nil
      parent.type(self, type, typecache)
   end
   return self
end
|
||
-- Release cached input/output/gradient buffers, then clear the children
-- (on self.device when in CUDA mode).
function GPU:clearState()
   self.output = nil
   self.gradInput = nil
   self._input = nil
   self._gradOutput = nil

   local function clear()
      parent.clearState(self)
   end
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, clear)
   else
      clear()
   end
end
|
||
-- Zero the accumulated gradients, on self.device when in CUDA mode.
function GPU:zeroGradParameters()
   if self._type ~= 'torch.CudaTensor' then
      parent.zeroGradParameters(self)
      return
   end
   cutorch.withDevice(self.device, function() parent.zeroGradParameters(self) end)
end
|
||
-- Apply a gradient-descent step with learning rate `lr`, on self.device
-- when in CUDA mode.
function GPU:updateParameters(lr)
   if self._type ~= 'torch.CudaTensor' then
      parent.updateParameters(self, lr)
      return
   end
   cutorch.withDevice(self.device, function() parent.updateParameters(self, lr) end)
end
|
||
-- Switch to training mode, on self.device when in CUDA mode.
function GPU:training()
   if self._type ~= 'torch.CudaTensor' then
      parent.training(self)
      return
   end
   cutorch.withDevice(self.device, function() parent.training(self) end)
end
|
||
-- Switch to evaluation mode, on self.device when in CUDA mode.
function GPU:evaluate()
   if self._type ~= 'torch.CudaTensor' then
      parent.evaluate(self)
      return
   end
   cutorch.withDevice(self.device, function() parent.evaluate(self) end)
end
|
||
-- Share the listed parameter fields with `mlp`, on self.device when in
-- CUDA mode. Returns self for chaining.
function GPU:share(mlp, ...)
   local args = {...}
   local function doShare()
      parent.share(self, mlp, unpack(args))
   end
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, doShare)
   else
      doShare()
   end
   return self
end
|
||
-- Deep-copy the container (delegates to nn.Module.clone). In CUDA mode the
-- copy is materialized on self.device.
function GPU:clone(...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      -- BUGFIX: the closure must `return` the clone. cutorch.withDevice
      -- forwards its closure's results, so without the inner return this
      -- method silently returned nil in CUDA mode.
      return cutorch.withDevice(self.device, function()
         return parent.clone(self, unpack(args))
      end)
   else
      return parent.clone(self, unpack(args))
   end
end
|
||
-- Serialize: write a small header {type, device} first so that read() can
-- deserialize the payload on the correct device.
function GPU:write(file)
   -- snapshot all instance fields into a plain table
   local object = {}
   for key, value in pairs(self) do
      object[key] = value
   end
   local header = {self._type, self.device}
   file:writeObject(header)
   file:writeObject(object)
end
|
||
-- Deserialize: read the {type, device} header written by write(), then load
-- the payload on that device when it was saved from CUDA mode.
function GPU:read(file)
   local header = file:readObject()
   local saved_type, saved_device = header[1], header[2]

   local object
   if saved_type == 'torch.CudaTensor' then
      object = cutorch.withDevice(saved_device, function() return file:readObject() end)
   else
      object = file:readObject()
   end

   -- restore all instance fields
   for key, value in pairs(object) do
      self[key] = value
   end
end
|
||
-- Pretty-print as "nn.GPU(<device>) @ <decorated module>".
function GPU:__tostring__()
   local inner = self.modules[1]
   local desc
   if inner.__tostring__ then
      desc = inner:__tostring__()
   else
      desc = torch.type(inner)
   end
   return torch.type(self) .. '(' .. self.device ..') @ ' .. desc
end
|
||
-- Fused accumulate-and-update is unsupported; use accGradParameters
-- followed by updateParameters instead.
function GPU:accUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented")
end
|
||
-- Shared fused accumulate-and-update is likewise unsupported.
function GPU:sharedAccUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented")
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters