
Should dp.TextSource provide an impl of frequencyTree()? #154

Open
hsheil opened this issue Aug 20, 2015 · 1 comment

Comments

@hsheil

hsheil commented Aug 20, 2015

When running recurrentlanguagemodel.lua with a custom text dataset and --softmaxtree, the following error occurs:

```
/home/hsheil/torch/install/bin/luajit: recurrentlanguagemodel.lua:222: attempt to call method 'frequencyTree' (a nil value)
stack traceback:
    recurrentlanguagemodel.lua:222: in main chunk
    [C]: in function 'dofile'
    ...heil/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:131: in main chunk
    [C]: at 0x00406670
```

Line 222 is the culprit:

```lua
   elseif opt.softmaxtree then -- uses frequency based tree
      local tree, root = ds:frequencyTree()
      softmax = nn.SoftMaxTree(inputSize, tree, root, opt.accUpdate)
   end
```

I guess --softmaxtree is most beneficial for the Billion Word dataset anyway, but to avoid this error, should dp.TextSource provide an implementation of the frequencyTree() method?
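
A minimal, torch-free sketch of why the call fails and how a script could fail early instead. `ds:frequencyTree()` is Lua sugar for `ds.frequencyTree(ds)`; when neither the object nor its class (reached through the metatable's `__index`) defines the method, the lookup yields nil and calling it raises the error above. The helpers `makeSource` and `checkFrequencyTree` are hypothetical and not part of dp or recurrentlanguagemodel.lua. The exact wording of the nil-call error also differs between LuaJIT/Lua 5.1 and Lua 5.3+.

```lua
-- Torch-free sketch; makeSource and checkFrequencyTree are hypothetical
-- helpers, not part of dp or recurrentlanguagemodel.lua.

-- Build an object whose methods are looked up through its metatable,
-- roughly the way torch.class instances resolve methods on their class.
local function makeSource(methods)
   local Class = methods or {}
   Class.__index = Class
   return setmetatable({}, Class)
end

-- ds:frequencyTree() is sugar for ds.frequencyTree(ds). Checking the field
-- first lets a script report the unsupported option instead of failing
-- with the generic nil-call error.
local function checkFrequencyTree(ds, name)
   if type(ds.frequencyTree) ~= 'function' then
      error(string.format(
         "--softmaxtree needs a dataset with a frequencyTree() method; %s has none",
         name), 2)
   end
   return ds:frequencyTree()
end

-- A source without the method reproduces the original failure
-- (message wording differs between LuaJIT/Lua 5.1 and Lua 5.3+).
local plain = makeSource()
local ok, err = pcall(function() return plain:frequencyTree() end)
print(ok, err)

-- With the guard, the error names the dataset instead.
print(pcall(checkFrequencyTree, plain, "TextSource"))
```

In the script itself, such a check would sit just before the `ds:frequencyTree()` call in the `opt.softmaxtree` branch; that placement is a suggestion, not existing code.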

Command-line parameters to reproduce (I think only the combination of --softmaxtree and --dataset TextSource is pertinent):

```
th recurrentlanguagemodel.lua --lstm --cuda --dataPath data --dataset TextSource --softmaxtree
```

@adonisues

I had the same problem. It seems that dp.TextSource has no way to initialize the SoftMaxTree (it lacks a frequencyTree() method), so I solved it as follows.

Add the following method (copied from dp/data/penntreebank.lua) to dp/data/textsource.lua:

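```lua
-- this can be used to initialize a SoftMaxTree
function TextSource:frequencyTree(binSize)
   binSize = binSize or 100
   -- word ids sorted by ascending frequency
   local wf = torch.IntTensor(self:wordFrequency())
   local vals, indices = wf:sort()
   local tree = {}
   -- parent node ids are allocated after the last word id
   local id = indices:size(1)
   -- declared local so it does not leak into the global namespace
   local function recursiveTree(indices)
      if indices:size(1) < binSize then
         -- fewer than binSize nodes left: they become the root's children
         id = id + 1
         tree[id] = indices
         return
      end
      -- group nodes into bins of at most binSize; each bin gets a new parent id
      local parents = {}
      for start=1,indices:size(1),binSize do
         local stop = math.min(indices:size(1), start+binSize-1)
         local bin = indices:narrow(1, start, stop-start+1)
         assert(bin:size(1) <= binSize)
         id = id + 1
         table.insert(parents, id)
         tree[id] = bin
      end
      -- repeat on the new parent ids until fewer than binSize remain
      recursiveTree(indices.new(parents))
   end
   recursiveTree(indices)
   -- tree maps parent id -> tensor of child ids; id is now the root id
   return tree, id
end
```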
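
To see what the returned `tree` and root id look like, here is a torch-free re-expression of the same binning using plain Lua tables. `frequencyTreeTable` is a hypothetical name; it takes a plain array of word frequencies in place of `self:wordFrequency()` and returns tables where the real method stores tensors. Word ids are sorted by ascending frequency, grouped into bins of at most `binSize`, each bin gets a fresh parent id above the vocabulary size, and the parents are binned again until fewer than `binSize` nodes remain, which become the children of the root.

```lua
-- Torch-free sketch of the binning in frequencyTree(). frequencyTreeTable is
-- a hypothetical name: it takes a plain array of word frequencies (standing
-- in for self:wordFrequency()) and returns plain tables instead of tensors.
local function frequencyTreeTable(freqs, binSize)
   binSize = binSize or 100
   -- word ids 1..#freqs sorted by ascending frequency, like wf:sort()
   local order = {}
   for i = 1, #freqs do order[i] = i end
   table.sort(order, function(a, b) return freqs[a] < freqs[b] end)

   local tree = {}
   local id = #freqs -- parent ids are allocated after the last word id
   local function recursiveTree(nodes)
      if #nodes < binSize then
         -- fewer than binSize nodes left: they become the root's children
         id = id + 1
         tree[id] = nodes
         return
      end
      -- split nodes into bins of at most binSize, one new parent per bin
      local parents = {}
      for start = 1, #nodes, binSize do
         local stop = math.min(#nodes, start + binSize - 1)
         local bin = {}
         for k = start, stop do bin[#bin + 1] = nodes[k] end
         id = id + 1
         parents[#parents + 1] = id
         tree[id] = bin
      end
      recursiveTree(parents)
   end
   recursiveTree(order)
   return tree, id -- tree: parent id -> child ids; id: root id
end

-- five words with distinct frequencies, bins of at most 3
local tree, root = frequencyTreeTable({50, 10, 30, 20, 40}, 3)
print(root)                       --> 8
print(table.concat(tree[6], ",")) --> 2,4,3
print(table.concat(tree[7], ",")) --> 5,1
print(table.concat(tree[8], ",")) --> 6,7
```

The five words sort as 2, 4, 3, 5, 1 and fall into bins {2, 4, 3} (parent 6) and {5, 1} (parent 7); the root, id 8, has children 6 and 7. One quirk carries over from the original: a level with exactly `binSize` nodes fails the `<` test, so it is binned once more into a single parent and the root ends up with just one child.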
