taichi.ad

Module Contents

Functions

grad_replaced(func)

A decorator for a Python function that customizes its gradient in Taichi's autodiff system.

grad_for(primal)

Generates a decorator to decorate primal's customized gradient function.

taichi.ad.grad_replaced(func)

A decorator that lets a Python function supply its own gradient to Taichi’s autodiff system (e.g. ti.Tape() and kernel.grad()). It forces the autodiff system to use a user-defined gradient function for the decorated function. The customized gradient function must be decorated with grad_for().

Parameters

func (Callable) – The Python function to be decorated.

Returns

The decorated function.

Return type

Callable

Example:

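>>> import taichi as ti
>>> ti.init()
>>>
>>> # Fields read and written by the kernels below.
>>> x = ti.field(ti.f32, shape=8, needs_grad=True)
>>> y = ti.field(ti.f32, shape=8, needs_grad=True)
>>>
>>> @ti.kernel
... def multiply(a: ti.float32):
...     for I in ti.grouped(x):
...         y[I] = x[I] * a
...
>>> @ti.kernel
... def multiply_grad(a: ti.float32):
...     # User-defined backward rule; Taichi uses it as written for foo.
...     for I in ti.grouped(x):
...         x.grad[I] = y.grad[I] / a
...
>>> @ti.ad.grad_replaced
... def foo(a):
...     multiply(a)
...
>>> @ti.ad.grad_for(foo)
... def foo_grad(a):
...     multiply_grad(a)

taichi.ad.grad_for(primal)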

Generates a decorator to decorate primal’s customized gradient function. See grad_replaced() for examples.

Parameters

primal (Callable) – The primal function. It must be decorated with grad_replaced().

Returns

A decorator to apply to the customized gradient function.

Return type

Callable
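
To make the pairing between grad_replaced() and grad_for() concrete, the following is a minimal pure-Python sketch of the pattern. It is an illustration under stated assumptions, not Taichi's implementation: ToyTape, _active_tape, state, scale and scale_grad are hypothetical names, and the toy tape only records calls and replays the registered gradient functions in reverse order when the block exits.

```python
import functools

_active_tape = None  # hypothetical global; set while a ToyTape block is open


class ToyTape:
    """Toy stand-in for an autodiff tape (not Taichi's Tape)."""

    def __init__(self):
        self.calls = []  # (decorated function, positional args) in call order

    def __enter__(self):
        global _active_tape
        _active_tape = self
        return self

    def __exit__(self, exc_type, exc, tb):
        global _active_tape
        _active_tape = None
        if exc_type is None:
            # Replay the registered gradient functions in reverse call order.
            for fn, args in reversed(self.calls):
                fn.grad(*args)
        return False


def grad_replaced(func):
    """Mark func as having a user-supplied gradient."""

    @functools.wraps(func)
    def decorated(*args):
        if _active_tape is not None:
            _active_tape.calls.append((decorated, args))
        return func(*args)

    decorated.grad = None  # filled in later by grad_for
    return decorated


def grad_for(primal):
    """Return a decorator that registers a gradient function on primal."""

    def decorator(grad_func):
        if not hasattr(primal, "grad"):
            raise RuntimeError("primal must be decorated with grad_replaced")
        if primal.grad is not None:
            raise RuntimeError("primal already has a customized gradient")
        primal.grad = grad_func
        return grad_func

    return decorator


# Usage: y = x * a, with a hand-written backward rule.
state = {"x": 3.0, "y": 0.0, "x_grad": 0.0, "y_grad": 1.0}


@grad_replaced
def scale(a):
    state["y"] = state["x"] * a


@grad_for(scale)
def scale_grad(a):
    state["x_grad"] += state["y_grad"] * a


with ToyTape():
    scale(2.0)

print(state["y"], state["x_grad"])  # prints: 6.0 2.0
```

The sketch stores the gradient function as an attribute on the decorated primal, which is why grad_for() in this toy version rejects a primal that was not first wrapped by grad_replaced().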