Skip to content

Commit

Permalink
syntax_extensions.md: fix headings and Table of Contents
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 8, 2024
1 parent d503d48 commit dfc1858
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions lib/syntax_extensions.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
# Syntax extensions `%cd` and `%op` {#syntax-extensions-cd-and-op}
# Syntax extensions %cd and %op

- Table of contents
- [Preliminaries](#preliminaries)
- [The syntax for %op {#syntax-for-op}](#syntax-for-op)
- [The syntax for %cd](#syntax-for-cd)
- [The syntax for %op](#the-syntax-for-op)
- [The syntax for %cd](#the-syntax-for-cd)
- [Numeric and N-dimensional array literals](#numeric-and-n-dimensional-array-literals)
- [Wildcard bindings](#wildcard-bindings)
- [Inline declarations](#inline-declarations)
- [Using OCANNL's generalized einsum notation](#using-ocannls-generalized-einsum-notation)
- [Further features of the syntax extension %cd](#features-of-syntax-cd)
- [Syntax of the generalized einsum notation](#syntax-of-the-generalized-einsum-notation)
- [Further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd)
- [Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node](#referencing-arrays-tensor-value-tensor-gradient-merge-buffer-of-a-tensor-node)
- [Block comments](#block-comments)
- [Further features of the syntax extension %op](#features-of-syntax-op)
- [Further features of the syntax extension %op](#further-features-of-the-syntax-extension-op)
- [Name from binding](#name-from-binding)
- [Label from function argument](#label-from-function-argument)
- [Configuring inline declarations: inline output dimensions, initial values](#configuring-inline-declarations-inline-output-dimensions-initial-values)
- [Lifting of the applications of ~config arguments: if it's an error, refactor your code](#lifting-of-the-applications-of-config-arguments-if-its-an-error-refactor-your-code)
- [Lifting of the applications of config arguments: if an error, refactor your code](#lifting-of-the-applications-of-config-arguments-if-an-error-refactor-your-code)
- [Implementation details](#implementation-details)
- [The hard-coded to-the-power-of operator](#the-hard-coded-to-the-power-of-operator)
- [Intricacies of the syntax extension %cd](#implementation-extension-cd)
- [Intricacies of the syntax extension %cd](#intricacies-of-the-syntax-extension-cd)
- In a nutshell
- Syntax extension `%cd` stands for "code", to express assignments and computations: `Assignments.comp`.
- Syntax extension `%op` stands for "operation", to express tensors: `Tensor.t`.
Expand All @@ -36,7 +37,7 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i

The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding operators.

## The syntax for `%op` {#syntax-for-op}
## The syntax for %op

The `%op` syntax is simpler than the `%cd` syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:

Expand Down Expand Up @@ -70,7 +71,7 @@ When there is a function directly under the `%op` extension point, like in the e

When the declaration is followed by a literal float, the float provides the initial value to initialize the tensor. Otherwise, the tensor value cells are initialized randomly with uniform distribution.

## The syntax for `%cd` {#syntax-for-cd}
## The syntax for %cd

The basic building blocks of the `%cd` syntax are individual assignments, separated by semicolons. The assignments, represented via `Assignments.Accum_binop` and `Assignments.Accum_unop`, are in full generality accumulating:

Expand Down Expand Up @@ -98,7 +99,7 @@ type Assignments.t =

For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.

The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightfowrad syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#features-of-syntax-cd).
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightfowrad syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).

How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.

Expand Down Expand Up @@ -190,8 +191,8 @@ The specification syntax has two modes:

The syntax of a generalized einsum spec has two variants:

- unary: "\<rhs\> shape spec `=>` \<lhs\> shape spec", specifies a unary assignment `<lhs> <asgn-op> <rhs>` (see [syntax for `%cd`](#syntax-for-cd)),
- binary: "\<rhs1\> shape spec `;` \<rhs2\> shape spec `=>` \<lhs\> shape spec", specifies a binary assignment `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (see [syntax for `%cd`](#syntax-for-cd)).
- unary: "\<rhs\> shape spec `=>` \<lhs\> shape spec", specifies a unary assignment `<lhs> <asgn-op> <rhs>` (see [syntax for `%cd`](#the-syntax-for-cd)),
- binary: "\<rhs1\> shape spec `;` \<rhs2\> shape spec `=>` \<lhs\> shape spec", specifies a binary assignment `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (see [syntax for `%cd`](#the-syntax-for-cd)).

Recall that a tensor _shape_ is composed of three _rows_, i.e. sequences of axes: batch, input and output axes. Correspondingly, a shape spec in the notation can be:

Expand Down Expand Up @@ -229,7 +230,7 @@ Examples:
- `..v..|ijk => ..v..kji`: reverse the three rightmost output axes, reduce any other output and input axes, pointwise for batch axes, pairing the batch axes with the leftmost output axes of the result.
- `2..v..|... => ..v..`: slice the tensor at dimension 2 of the leftmost batch axis, reduce all its input and output axes, preserve its other batch axes as output axes.

## Further features of the syntax extension `%cd` {#features-of-syntax-cd}
## Further features of the syntax extension %cd

### Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node

Expand All @@ -256,7 +257,7 @@ type Assignments.t =

Schematic example: `~~("space" "separated" "comment" "tensor p debug_name:" p; <scope of the comment>)`. The content of the comment uses application syntax, must be composed of strings, `<tensor>`, `<tensor>.value` (equivalent to `<tensor>`), `<tensor>.grad` components, where `<tensor>` is any tensor expression or tensor identifier.

## Further features of the syntax extension `%op` {#features-of-syntax-op}
## Further features of the syntax extension %op

### Name from binding

Expand All @@ -281,7 +282,7 @@ If it is a list expression following an inline declaration, the expression is pa
...
```

### Lifting of the applications of `~config` arguments: if it's an error, refactor your code
### Lifting of the applications of config arguments: if an error, refactor your code

If you recall, inline declared param tensors get lifted out of functions except for the function `fun ~config ->`, where they get defined. Our example `let%op mlp_layer ~config x = ?/ ("w" * x + "b" config.hid_dim)` translates as:

Expand Down Expand Up @@ -369,7 +370,7 @@ let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =

On the `Tensor` level, this is implemented as a binary tensor operation, but it is exposed as a unary tensor operation! To avoid the complexities of propagating gradient into the exponent, `Operation.pointpow` is implemented as a function of only one tensor, the exponent is a number. We hard-code the pointwise-power-of operator `NTDSL.O.( **. )`, resp. `TDSL.O.( **. )`, in the `%cd` and `%op` syntaxes, to pass the numeric value to `pointpow` (the second argument of `**.`) without converting it to a tensor first.

### Intricacies of the syntax extension `%cd` {#implementation-extension-cd}
### Intricacies of the syntax extension %cd

The syntax `%cd` translator needs to accomplish more than a context-free conversion of a concise notation to an `Assignments.t` data-type.

Expand Down

0 comments on commit dfc1858

Please sign in to comment.