jax循环语句
import jax
import jax.numpy as jnp
### 1. jax.lax.while_loop
# jax.lax.while_loop(cond_fun, body_fun, init_val)
# Repeatedly applies body_fun to the loop state while cond_fun(state) is true,
# then returns the final state. Here the state is a counter that starts at 0
# and is incremented until it reaches 10, so the printed result is 10.
init_val = 0


def cond_fun(x):
    """Keep looping while the counter is below 10."""
    return x < 10


def body_fun(x):
    """Advance the counter by one."""
    return x + 1


result = jax.lax.while_loop(cond_fun, body_fun, init_val)
print(result)
### 2. jax.lax.fori_loop
# jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)
# Calls body_fun(i, val) for each i in the half-open range [lower, upper),
# threading val through the calls. Here val accumulates the indices 0..9,
# so the printed result is 0 + 1 + ... + 9 = 45.
init_val = 0
start = 0
stop = 10


def body_fun(i, x):
    """Add the loop index ``i`` to the running accumulator ``x``."""
    return x + i


result = jax.lax.fori_loop(start, stop, body_fun, init_val)
print(result)
### 3. jax.lax.scan
# jax.lax.scan(f, init, xs, length=None, reverse=False, unroll=1)
# With scan, the carry must be declared explicitly as the step function's
# first parameter, and the new carry must be the first element of the pair
# the step function returns; the second element is the per-step output,
# which scan stacks across all steps.
def f(carry, x):
    """Scan step that produces a running (cumulative) sum.

    Returns ``(new_carry, y)``, where both elements are ``carry + x``.
    """
    running_total = carry + x
    return running_total, running_total


xs = jnp.array([0, 1, 2, 3])
# result is the pair (final_carry, stacked_outputs): (6, [0, 1, 3, 6]).
result = jax.lax.scan(f, 0, xs)
print(result)
# haiku.scan(f, init, xs, length=None, reverse=False, unroll=1)
# Equivalent to jax.lax.scan() but with Haiku state passed in/out.
参考:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#
https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=haiku.scan#haiku.scan
原文地址:https://blog.csdn.net/qq_27390023/article/details/135712369
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!