Support Matrix Multiplication #291

Open
CGMossa opened this issue May 5, 2023 · 6 comments

@CGMossa

CGMossa commented May 5, 2023

Benchmarking R + deSolve code against the equivalent odin code gave a surprising result:

# A tibble: 2 × 13
  expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory     time      
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>     <list>    
1 odin_c        303ms    305ms      3.28    8.59MB      0       2     0      610ms <NULL> <Rprofmem> <bench_tm>
2 deSolve_r     121ms    121ms      8.23   38.04MB     24.7     1     3      121ms <NULL> <Rprofmem> <bench_tm>
# ℹ 1 more variable: gc <list>

I suspect the culprit is the lack of matrix multiplication in odin (or maybe I don't know how to invoke it).
In the deSolve part:

between_sites <- transmission_rate * S * (foi_matrix %*% I)

However, in the odin part I do:

foi[, ] <- foi_mat[i, j] * I[j]
delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i, ])
deriv(S[]) <- -delta_transmission[i]
deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]

Here I've omitted the parts I don't think are necessary.
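Both versions compute the same between-site term. With foi[i, j] = foi_mat[i, j] * I[j], sum(foi[i, ]) is element i of foi_mat %*% I. The C sketch below uses plain loops and no BLAS. The function names, and the argument infected standing in for I, are hypothetical. It contrasts two strategies. The odin formulation fills a full N-by-N array on every derivative evaluation and then sums its rows. A direct matrix-vector product needs only a length-N output. Both functions assume R's column-major layout, where element [i, j] of an N-by-N matrix sits at index i + j * N.

```c
#include <stddef.h>

/* Column-major index of element [i, j] in an n-by-n matrix (R's layout). */
static size_t cm(size_t i, size_t j, size_t n) { return i + j * n; }

/* Mirrors the odin code: fill an n-by-n scratch array foi, then sum each row.
   The caller supplies the scratch buffer foi (n * n doubles). */
void between_via_foi(size_t n, const double *foi_mat, const double *infected,
                     double *foi, double *out) {
  for (size_t j = 0; j < n; j++)
    for (size_t i = 0; i < n; i++)
      foi[cm(i, j, n)] = foi_mat[cm(i, j, n)] * infected[j];
  for (size_t i = 0; i < n; i++) {
    double total = 0.0;
    for (size_t j = 0; j < n; j++)
      total += foi[cm(i, j, n)];
    out[i] = total;
  }
}

/* Direct matrix-vector product out = foi_mat %*% infected, no n-by-n scratch.
   The outer loop runs over columns, so foi_mat is read contiguously. */
void between_direct(size_t n, const double *foi_mat, const double *infected,
                    double *out) {
  for (size_t i = 0; i < n; i++)
    out[i] = 0.0;
  for (size_t j = 0; j < n; j++) {
    const double xj = infected[j];
    for (size_t i = 0; i < n; i++)
      out[i] += foi_mat[cm(i, j, n)] * xj;
  }
}
```

Both functions return the same numbers. The direct version avoids writing and re-reading an N-by-N intermediate on every call.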

  • Is matrix multiplication supported?
    Unfortunately, R's C-facilities have a weird way of doing matrix multiplication, so it might not be supported yet.

I'll work on a minimal testcase to check if this is indeed the problem.

Under details, I have more complete excerpts of my code:

Details

odin::odin({
  foi[, ] <- foi_mat[i, j] * I[j]
  delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i,])
  delta_recovery[] <- recovery_rate * I[i]
  deriv(S[]) <- -delta_transmission[i]
  deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]
  deriv(R[]) <- delta_recovery[i]

  source_id <- user()
  target_id <- user()
  Total[] <- S[i] + I[i] + R[i]
  output(source_prevalence) <- I[as.integer(source_id)] / Total[as.integer(source_id)]
  output(target_prevalence) <- I[as.integer(target_id)] / Total[as.integer(target_id)]

  transmission_rate <- user(0.05)
  recovery_rate <- user(0.01)
  S0[] <- user()
  I0[] <- user()
  foi_mat[,] <- user()
  initial(S[]) <- S0[i]
  initial(I[]) <- I0[i]
  initial(R[]) <- Total[i] - S0[i] - I0[i]
  dim(S0) <- user()
  dim(S) <- N
  dim(I) <- N
  dim(R) <- N
  dim(I0) <- N
  dim(foi_mat) <- c(N, N)
  dim(foi) <- c(N, N)
  dim(Total) <- N
  dim(delta_transmission) <- N
  dim(delta_recovery) <- N
  N <- length(S0)
},
verbose = TRUE, validate = TRUE, target = "c", pretty = TRUE,
skip_cache = FALSE) ->
  model_generator

model <-
  model_generator$new(S0 = site_S,
                      I0 = site_I,
                      foi_mat = foi_matrix,
                      source_id = as.integer(source_site_id),
                      target_id = as.integer(target_site_id))
model$set_user(transmission_rate = 0.05, recovery_rate = 0.01)

Benchmarking:

bench::mark(
  odin_c = model$run(0:100),
  deSolve_r = {
    
    site_R <- site_S
    site_R[] <- 0
    deSolve::ode(
      y = c(S = site_S, I = site_I, R = site_R),
      times = 0:100,
      func = function(time, state, parms) {
        with(parms, {
          S <- state[1:N]
          I <- state[(N + 1):(2 * N)]
          R <- state[(2 * N + 1):(3 * N)]
          
          Total <- S + I + R
          
          between_sites <- transmission_rate * S * (foi_matrix %*% I)
          
          
          source_prevalence <- I[[source_id]] / Total[[source_id]]
          target_prevalence <- I[[target_id]] / Total[[target_id]]
          
          list(c(
            dS = -transmission_rate * S * I - between_sites,
            dI = +transmission_rate * S * I + between_sites - recovery_rate * I,
            dR = recovery_rate * I
          ),
          source_prevalence = source_prevalence,
          target_prevalence = target_prevalence)
        })
      },
      parms = list(
        transmission_rate = 0.05,
        recovery_rate = 0.01,
        source_id = as.integer(source_site_id),
        target_id = as.integer(target_site_id),
        N = length(site_S)
      )
    )
  },
  check = FALSE
) %>% 
  print()

@richfitz
Member

richfitz commented May 5, 2023

Unfortunately this is not that surprising: if the model is dominated by a matrix multiplication, then the version that uses a linear algebra library will be much faster.

Supporting this properly has been on the back burner for a long time (#38, #134, #213; these mostly concern multinomial distributions, but the syntactic issue in #134 is shared and is the primary blocker). The actual calling convention is not that bad, though it does mean that models need a working copy of gfortran to compile, which is quite annoying in practice, particularly for people on Macs.

@CGMossa
Author

CGMossa commented May 5, 2023

I'm glad you agree. For my use case, I can work around this by being a little more clever. But to stick to this issue, and since you know this stuff already:

  • How come you cannot inherit the OS-specific settings that R uses for itself on each platform?
    First, I would think (maybe naively) that you could use R CMD config to get the right compile flags on each platform:
C:\Users\minin>R CMD config LAPACK_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRlapack

But if I just think about BLAS (whatever that is), R's header R_ext/BLAS.h first says:

 R packages that use these should have PKG_LIBS in src/Makevars include
   $(BLAS_LIBS) $(FLIBS)

So on my Windows machine it is

C:\Users\minin>R CMD config FLIBS
-lgfortran -lm -lquadmath

C:\Users\minin>R CMD config BLAS_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRblas

Then apparently dgemm is the Fortran routine that is supposed to do this. I've copied its prototype from the header:

/* DGEMM - perform one of the matrix-matrix operations    */
/* C := alpha*op( A )*op( B ) + beta*C */
BLAS_extern void
F77_NAME(dgemm)(const char *transa, const char *transb, const int *m,
		const int *n, const int *k, const double *alpha,
		const double *a, const int *lda,
		const double *b, const int *ldb,
		const double *beta, double *c, const int *ldc 
		FCLEN FCLEN);
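Here is a plain-C sketch of what dgemm computes when transa and transb are both "N" (no transposes). It is a reference loop for illustration only, not R's BLAS. The name ref_dgemm_nn is hypothetical, and unlike the Fortran routine it takes its scalars by value rather than by pointer. All matrices are column-major. Each leading dimension (lda, ldb, ldc) is the step in memory from one column to the next, which equals the number of rows for an ordinary R matrix.

```c
#include <stddef.h>

/* C := alpha * A %*% B + beta * C, for A (m x k), B (k x n), C (m x n),
   all column-major with leading dimensions lda, ldb, ldc.
   Hypothetical reference loop: scalars by value, no transpose options. */
void ref_dgemm_nn(int m, int n, int k, double alpha,
                  const double *a, int lda,
                  const double *b, int ldb,
                  double beta, double *c, int ldc) {
  for (int j = 0; j < n; j++) {
    /* Scale column j of C by beta; beta == 0 overwrites without reading C. */
    for (int i = 0; i < m; i++)
      c[i + (size_t)j * ldc] = (beta == 0.0) ? 0.0 : beta * c[i + (size_t)j * ldc];
    /* Add alpha * A %*% B[, j], accumulating one column of A at a time. */
    for (int l = 0; l < k; l++) {
      const double temp = alpha * b[l + (size_t)j * ldb];
      for (int i = 0; i < m; i++)
        c[i + (size_t)j * ldc] += temp * a[i + (size_t)l * lda];
    }
  }
}
```

With beta = 0 the result simply replaces C, which is what a plain a %*% b needs. With beta = 1 the product is added to whatever C already holds.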

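Finally, I asked ChatGPT about this, and it suggested code along these lines for invoking it:

#define USE_FC_LEN_T        /* pass the hidden Fortran string lengths (FCONE) */
#include <R.h>
#include <Rinternals.h>
#include <R_ext/BLAS.h>     /* declares F77_NAME(dgemm) */
#ifndef FCONE
# define FCONE
#endif

/* result = a %*% b for two double matrices, computed by BLAS dgemm */
SEXP matrix_mult(SEXP a, SEXP b) {
  if (TYPEOF(a) != REALSXP || TYPEOF(b) != REALSXP) {
    error("Both arguments must be double matrices.");
  }
  int nrow_a = nrows(a);
  int ncol_a = ncols(a);
  int nrow_b = nrows(b);
  int ncol_b = ncols(b);

  if (ncol_a != nrow_b) {
    error("Matrix dimensions do not match for multiplication.");  /* does not return */
  }

  SEXP result = PROTECT(allocMatrix(REALSXP, nrow_a, ncol_b));

  if (nrow_a == 0 || ncol_a == 0 || ncol_b == 0) {
    /* BLAS rejects zero leading dimensions; an empty product is all zeros */
    for (R_xlen_t i = 0; i < XLENGTH(result); i++) REAL(result)[i] = 0.0;
  } else {
    double alpha = 1.0;  /* scale applied to a %*% b */
    double beta = 0.0;   /* overwrite result rather than accumulate into it */
    F77_CALL(dgemm)("N", "N", &nrow_a, &ncol_b, &ncol_a, &alpha, REAL(a), &nrow_a,
                    REAL(b), &nrow_b, &beta, REAL(result), &nrow_a FCONE FCONE);
  }

  UNPROTECT(1);
  return result;
}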

I don't know where these "N" arguments come from.
There is also more than one of these routines, and this one is specifically matrix-matrix (while I apparently need matrix-vector).
Presumably the SEXPTYPEs are what differentiates them.
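On the "N" arguments and the matrix-vector question: in BLAS, the character arguments say whether to transpose each matrix before multiplying. "N" uses the matrix as-is and "T" uses its transpose. The matrix-vector counterpart of dgemm is dgemv, which computes y := alpha*op(A)*x + beta*y and is also declared in R's R_ext/BLAS.h. The routine is picked by its name rather than by the SEXPTYPE of the R objects. The name encodes d for double, ge for a general matrix, and mv or mm for matrix-vector or matrix-matrix. The sketch below shows what the transpose switch changes. It uses a hypothetical name, ref_dgemv, supports unit vector strides only, and calls no BLAS routine.

```c
#include <stddef.h>

/* y := alpha * op(A) %*% x + beta * y, A is m x n column-major (leading dim lda).
   trans == 'N': op(A) = A,    x has n entries, y has m.
   any other value (meant as 'T'): op(A) = t(A), x has m entries, y has n.
   Hypothetical reference loop with unit strides only. */
void ref_dgemv(char trans, int m, int n, double alpha,
               const double *a, int lda, const double *x,
               double beta, double *y) {
  const int leny = (trans == 'N') ? m : n;
  for (int i = 0; i < leny; i++)
    y[i] = (beta == 0.0) ? 0.0 : beta * y[i];
  if (trans == 'N') {
    /* Walk A column by column: y += alpha * x[j] * A[, j]. */
    for (int j = 0; j < n; j++) {
      const double temp = alpha * x[j];
      for (int i = 0; i < m; i++)
        y[i] += temp * a[i + (size_t)j * lda];
    }
  } else {
    /* Transposed: y[j] is the dot product of column j of A with x. */
    for (int j = 0; j < n; j++) {
      double dot = 0.0;
      for (int i = 0; i < m; i++)
        dot += a[i + (size_t)j * lda] * x[i];
      y[j] += alpha * dot;
    }
  }
}
```

Calling dgemm with a single right-hand column (n = 1) also gives a matrix-vector product. dgemv is simply the routine specialised for that case.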

I've googled, and BLAS should be supported on Mac.
I don't know how that relates to LAPACK, or where they are switched or changed.

Details

Usage: R CMD config [options] [VAR]

Get the value of a basic R configure variable VAR which must be among
those listed in the 'Variables' section below, or the header and
library flags necessary for linking a front-end against R.

Options:
  -h, --help            print short help message and exit
  -v, --version         print version info and exit
      --cppflags        print pre-processor flags required to compile a
                        C/C++ file as part of a front-end using R as a library
      --ldflags         print linker flags needed for linking a front-end
                        against the R library
      --no-user-files   ignore customization files under ~/.R
      --no-site-files   ignore site customization files under R_HOME/etc
      --all             print names and values of all variables below

Variables:
  AR            command to make static libraries
  BLAS_LIBS     flags needed for linking against external BLAS libraries
  CC            C compiler command
  CFLAGS        C compiler flags
  CC17          Ditto for the C17 or earlier compiler
  C17FLAGS
  CC23          Ditto for the C23 or later compiler
  C23FLAGS
  CPICFLAGS     special flags for compiling C code to be included in a
                shared library
  CPPFLAGS      C/C++ preprocessor flags, e.g. -I<dir> if you have
                headers in a nonstandard directory <dir>
  CXX           default compiler command for C++ code
  CXXFLAGS      compiler flags for CXX
  CXXPICFLAGS   special flags for compiling C++ code to be included in a
                shared library
  CXX11         compiler command for C++11 code
  CXX11STD      flag used with CXX11 to enable C++11 support
  CXX11FLAGS    further compiler flags for CXX11
  CXX11PICFLAGS
                special flags for compiling C++11 code to be included in
                a shared library
  CXX14         compiler command for C++14 code
  CXX14STD      flag used with CXX14 to enable C++14 support
  CXX14FLAGS    further compiler flags for CXX14
  CXX14PICFLAGS
                special flags for compiling C++14 code to be included in
                a shared library
  CXX17         compiler command for C++17 code
  CXX17STD      flag used with CXX17 to enable C++17 support
  CXX17FLAGS    further compiler flags for CXX17
  CXX17PICFLAGS
                special flags for compiling C++17 code to be included in
                a shared library
  CXX20         compiler command for C++20 code
  CXX20STD      flag used with CXX20 to enable C++20 support
  CXX20FLAGS    further compiler flags for CXX20
  CXX23         compiler command for C++23 code
  CXX23STD      flag used with CXX23 to enable C++23 support
  CXX23FLAGS    further compiler flags for CXX23
  CXX23PICFLAGS
                special flags for compiling C++23 code to be included in
                a shared library
  DYLIB_EXT     file extension (including '.') for dynamic libraries
  DYLIB_LD      command for linking dynamic libraries which contain
                object files from a C or Fortran compiler only
  DYLIB_LDFLAGS
                special flags used by DYLIB_LD
  FC            Fortran compiler command
  FFLAGS        fixed-form Fortran compiler flags
  FCFLAGS       free-form Fortran 9x compiler flags
  FLIBS         linker flags needed to link Fortran code
  FPICFLAGS     special flags for compiling Fortran code to be turned
                into a shared library
  JAR           Java archive tool command
  JAVA          Java interpreter command
  JAVAC         Java compiler command
  JAVAH         Java header and stub generator command
  JAVA_HOME     path to the home of Java distribution
  JAVA_LIBS     flags needed for linking against Java libraries
  JAVA_CPPFLAGS C preprocessor flags needed for compiling JNI programs
  LAPACK_LIBS   flags needed for linking against external LAPACK libraries
  LIBnn         location for libraries, e.g. 'lib' or 'lib64' on this platform
  LDFLAGS       linker flags, e.g. -L<dir> if you have libraries in a
                nonstandard directory <dir>
  LTO LTO_FC LTO_LD  flags for Link-Time Optimization
  MAKE          Make command
  NM            comand to display symbol tables
  OBJC          Objective C compiler command
  OBJCFLAGS     Objective C compiler flags
  RANLIB        command to index static libraries
  SAFE_FFLAGS   Safe (as conformant as possible) Fortran compiler flags
  SHLIB_CFLAGS  additional CFLAGS used when building shared objects
  SHLIB_CXXFLAGS
                additional CXXFLAGS used when building shared objects
  SHLIB_CXXLD   command for linking shared objects which contain
                object files from a C++ compiler (and CXX11 CXX14 CXX17 CXX20 CXX23)
  SHLIB_CXXLDFLAGS
                special flags used by SHLIB_CXXLD (and CXX11 CXX14 CXX17 CXX20 CXX23)
  SHLIB_EXT     file extension (including '.') for shared objects
  SHLIB_FFLAGS  additional FFLAGS used when building shared objects
  SHLIB_LD      command for linking shared objects which contain
                object files from a C or Fortran compiler only
  SHLIB_LDFLAGS
                special flags used by SHLIB_LD
  TCLTK_CPPFLAGS
                flags needed for finding the tcl.h and tk.h headers
  TCLTK_LIBS    flags needed for linking against the Tcl and Tk libraries

Windows only:
  COMPILED_BY   name and version of compiler used to build R
  LOCAL_SOFT    absolute path to '/usr/local' software collection
  R_TOOLS_SOFT  absolute path to 'R tools' software collection
  OBJDUMP       command to dump objects

Report bugs at <https://bugs.R-project.org>.

@CGMossa
Author

CGMossa commented May 5, 2023

@richfitz
Member

richfitz commented May 5, 2023

Thanks, that part is straightforward and we do it elsewhere (for example https://github.com/mrc-ide/eigen1/blob/master/src/util.c#L16-L17). The pain comes when users have not correctly installed the Fortran parts of the toolchain, and on Macs that changes every couple of years as Apple and R-core change how things get installed.

The blocker on this is the odin syntax, and that's been unresolved for about 5 years, so I doubt we will get to it soon!

@CGMossa
Author

CGMossa commented May 5, 2023

Good. I won't comment on the syntax just yet, especially since I don't know anything about parsers. I guess the problem is that right now line order doesn't matter, but for the three-step definition it would need to? In any case, thanks for indulging this conversation.

For my personal understanding: on Windows we have Rblas.dll, and I had hoped it was possible to just link to that without needing a Fortran compiler. On Windows, however, we also have Rtools, which most likely contains a Fortran compiler, so I don't really have experience with this. I would have guessed that -shared plus linking to Rblas.dll (or its equivalent elsewhere) would have been enough...

@richfitz
Member

richfitz commented May 5, 2023

Windows tends to be fine because R core controls the whole toolchain. On Macs, you get link-time errors if libgfortran is not found.

Line order won't matter for this either. The intention is to support y <- A %*% x and convert that to the appropriate BLAS call based on what we know about y, A and x. The issue is when (inevitably) people want to apply these transformations to higher-order objects, looping over part of y at each operation, so at the moment we're thinking about things like:

y[., ] <- A[j, ., .] %*% x[., j]
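Here is one possible reading of the notation above. It is an assumption, not a description of odin's plan: for every j, column j of y is the matrix slice A[j, , ] times column j of x, so y[i, j] = sum over k of A[j, i, k] * x[k, j]. The plain-C loop below spells that out with R's column-major array layout. The function name batched_matvec and the dimension names are hypothetical. Consecutive rows of a slice A[j, , ] are nj entries apart in memory rather than adjacent. A standard BLAS call expects contiguous columns, so such a slice would need copying or a different storage order first.

```c
#include <stddef.h>

/* Hypothetical batched product: for each j in 0..nj-1,
   y[, j] = A[j, , ] %*% x[, j], using R's column-major layout:
     A has dim (nj, n, m): A[j, i, k] at j + nj * (i + n * k)
     x has dim (m, nj):    x[k, j]    at k + m * j
     y has dim (n, nj):    y[i, j]    at i + n * j            */
void batched_matvec(size_t nj, size_t n, size_t m,
                    const double *A, const double *x, double *y) {
  for (size_t j = 0; j < nj; j++) {
    for (size_t i = 0; i < n; i++) {
      double total = 0.0;
      for (size_t k = 0; k < m; k++)
        total += A[j + nj * (i + n * k)] * x[k + m * j];
      y[i + n * j] = total;
    }
  }
}
```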
