Multi-variate output #23

Closed
dsweber2 opened this issue Oct 29, 2015 · 11 comments

@dsweber2

So I'm looking for a way to calculate the total derivative in an AD manner, and was somewhat surprised to find that rdiff requires that the input function have only one output variable:

function fun(x)
  return [x[2],x[1]]
end

rdiff(fun, ([1,1],))

returns the error: output variable should be Real, Array{Int64,1} found.

I can see how for order >= 2, it's best to require that the output of the function be a single variable. However, for order 1 it seems perfectly reasonable to look for the total derivative. As far as ways to calculate the total derivative with the code as is, I suppose defining an expression for each coordinate of the function should work, and then stitching it back together using @eval, but this seems kludgy to me. Suggestions?

@mlubin
Contributor

mlubin commented Oct 29, 2015

There's typically little algorithmic benefit from applying reverse mode to a multivariate function versus applying it to multiple univariate functions that represent the output values of the original function. If you use forward-mode AD, then the complexity depends on the input dimension but not the output dimension.
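
To spell out the cost argument above, here is a short sketch in notation that is not from the thread (the symbols f, n, m and J are introduced here for illustration). For a function with n inputs and m outputs, the total derivative is the m-by-n Jacobian, whose rows are the gradients of the individual outputs:

```latex
f : \mathbb{R}^n \to \mathbb{R}^m, \qquad
J(x) =
\begin{pmatrix}
\nabla f_1(x)^\top \\
\vdots \\
\nabla f_m(x)^\top
\end{pmatrix}
\in \mathbb{R}^{m \times n}, \qquad
J_{ij} = \frac{\partial f_i}{\partial x_j}

% one reverse-mode pass yields one row, one forward-mode pass yields one column
e_i^\top J(x) = \nabla f_i(x)^\top \quad (\text{reverse}), \qquad
J(x)\, e_j = \frac{\partial f}{\partial x_j}(x) \quad (\text{forward})
```

Under this view, building the full Jacobian takes roughly m reverse passes (one per output) or n forward passes (one per input), which is why differentiating each output separately in reverse mode loses little compared with a dedicated multi-output implementation.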

@dsweber2
Author

Thanks for the fast turnaround!
So algorithmically, there's little difference between the multivariate and univariate cases? But from a coding perspective, it would be convenient to be able to put a multivariate function as input to rdiff.

As I'm relatively new to Julia, is there a simple way to feed a multivariate function through rdiff that I'm just not seeing (I tried using :(fun(x)[1]), and it didn't work)?

@mlubin
Contributor

mlubin commented Oct 29, 2015

Yes, I understand that it's not very convenient to split a multivariate output function into multiple univariate output functions. I don't think that's implemented; I'll let @fredo-dedup comment here.

@fredo-dedup
Contributor

Hi all,
I think this is a duplicate of #11. See the (kludgy, sorry) solution proposed there.
This one is marked as a needed enhancement. Now I just need the time to implement this someday.

@dsweber2
Author

Oh, I didn't notice how similar that one was, sorry for the redundancy. The suggestion there doesn't seem to work on functions however:

rdiff(fun[1], ([2,1],), order=1)
'getindex' has no method matching getindex(::Function, ::Int64)

I also tried wrapping an expression around it.

rdiff(:(fun(x)[1]),x=[2,1])
no derivation rule for fun at arg #1

I'm guessing there's a metaprogramming solution, but I don't have much experience with that yet. Suggestions?

@fredo-dedup
Contributor

Sorry for not providing a more detailed workaround!
If you are working with an expression, say :( log(x) ), append [i] to it to get :( log(x)[i] ), where i is a variable that you have already defined.
Run ex = rdiff( :( log(x)[i] ), x=[1,2]); the resulting expression depends on both i and x.
Then wrap the result in a loop:

@eval function mex(x)
  res = Array(Float64, length(x), length(x)) # will hold the result
  for i in 1:length(x)
    # $ex returns (value, gradient) for output i; keep the gradient as row i
    _ , res[i,:] = $ex
  end
  res
end

Executing mex(ones(5)) should now work as expected. Note that the code will be very inefficient. If possible, edit ex and move the calculations that do not depend on i out of the loop.

@dsweber2
Author

This works fine for functions where the derivative is already defined, such as log, but for ones such as fun above, rdiff doesn't go into the function (I tried this above).

@fredo-dedup
Contributor

This is correct: rdiff does not look into the definition of functions. Either it is a function with pre-declared derivation rules (ReverseDiffSource comes with definitions for the most basic functions), or you tell rdiff how to differentiate the function (see the doc on the @deriv_rule macro).
If fun is a function that you designed, perhaps you can add [i] at the end of its body and pass the resulting expression to rdiff?
If it is a pre-defined function, then you'll have to declare how it is differentiated with @deriv_rule.

I can try to take a look if you tell me more about your fun().

@dsweber2
Author

dsweber2 commented Nov 1, 2015

I've just been using the function defined above

function fun(x)
  return [x[2],x[1]]
end

as a test case (I have several functions that I would like to be able to run this on).
I'm not sure what you mean by adding [i] at the end; rdiff(fun[i]) throws a getindex error.

@fredo-dedup
Contributor

The [i] should be within the function or the expression. (Note that the function version of rdiff seems a bit buggy with your example, so I'll be using an expression below.)

exfun = quote
  tmp = [x[2],x[1]]
  tmp[i]  # the getindex is placed within the expression
end

i = 1 # i needs to be set for rdiff to work
dex = rdiff(exfun, x=ones(2))

@eval function mex(x)
  res = Array(Float64, length(x), length(x)) # will hold the result
  for i in 1:length(x)
    _ , res[i,:] = $dex
  end
  res
end

mex([0.,1])  
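
One way to sanity-check what a workaround like this returns is to compare it against a numerical Jacobian. The sketch below is not part of ReverseDiffSource and does not use AD at all: it is a plain central finite-difference approximation, written in current Julia syntax (unlike the 2015-era `Array(Float64, ...)` calls in the thread). The helper name `jacobian_fd` and the step size `h` are assumptions made for illustration.

```julia
# Test function from the thread: swaps the two inputs.
fun(x) = [x[2], x[1]]

# Hypothetical helper (not a ReverseDiffSource function): approximate the
# Jacobian of f at x by central finite differences, one input at a time.
function jacobian_fd(f, x; h=1e-6)
    m = length(f(x))          # number of outputs
    n = length(x)             # number of inputs
    J = zeros(m, n)
    for j in 1:n
        e = zeros(n)
        e[j] = h              # perturb only input j
        # column j holds the derivative of every output with respect to x[j]
        J[:, j] = (f(x .+ e) .- f(x .- e)) ./ (2h)
    end
    return J
end

J = jacobian_fd(fun, [0.0, 1.0])
println(J)
```

For `fun`, the first output depends only on `x[2]` and the second only on `x[1]`, so the result should be close to the matrix with zeros on the diagonal and ones off it. Row i of this matrix is the gradient of output i, the same layout that `res[i,:]` fills in `mex`.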

@fredo-dedup
Contributor

Closing because it is a duplicate.
