提问人:user1168149 提问时间:11/7/2023 更新时间:11/7/2023 访问量:40
嵌套类方法的 JAX @jit
JAX @jit for nested class method
问:
我正在尝试使用嵌套函数,但遇到了问题。
我有一个类,它用一个方法接受另一个类。
我想将此方法称为 jitted from .
我想我遵循了 JAX 的常见问题解答,“如何将 jit 与方法一起使用?https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods但是,我遇到了一个错误,说.
谁能告诉我如何解决这个问题?@jit
One
Plant
func
func
One
TypeError: One.__init__() got multiple values for argument 'plant'
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util
class One:
def __init__(self, plant,x):
self.plant = plant
self.x = x
@jit
def call_plant_func(self,y):
out = self.plant.func(y) + self.x
return out
def _tree_flatten(self):
children = (self.x,) # arrays / dynamic values
aux_data = {'plant':self.plant} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
import pdb; pdb.set_trace();
return cls(*children, **aux_data)
tree_util.register_pytree_node(One,
One._tree_flatten,
One._tree_unflatten)
class Plant:
def __init__(self, z,kk):
self.z =z
@jit
def func(self,y):
y = y + self.z
return y
def _tree_flatten(self):
children = (self.z,) # arrays / dynamic values
aux_data = None # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, children):
return cls(*children)
tree_util.register_pytree_node(Plant,
Plant._tree_flatten,
Plant._tree_unflatten)
plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))
最后一行给了我一个上面描述的错误。
答:
0赞
jakevdp
11/7/2023
#1
您在两个类的 and 代码中都有问题。tree_flatten
tree_unflatten
One._tree_flatten
被视为静态数据,但事实并非如此:它是一个具有非静态元素的 pytree。plant
One._tree_unflatten
实例化参数的顺序错误,导致您看到的错误One
Plant.__init__
对参数不执行任何操作。kk
Plant._tree_unflatten
缺少参数,并且无法将参数传递给aux_data
kk
Plant.__init__
修复这些问题后,代码将执行而不会出错:
class One:
def __init__(self, plant,x):
self.plant = plant
self.x = x
@jit
def call_plant_func(self,y):
out = self.plant.func(y) + self.x
return out
def _tree_flatten(self):
children = (self.plant, self.x)
aux_data = None
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children)
tree_util.register_pytree_node(One,
One._tree_flatten,
One._tree_unflatten)
class Plant:
def __init__(self, z, kk):
self.kk = kk
self.z =z
@jit
def func(self, y):
y = y + self.z
return y
def _tree_flatten(self):
children = (self.z, self.kk)
aux_data = None
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children)
tree_util.register_pytree_node(Plant,
Plant._tree_flatten,
Plant._tree_unflatten)
plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))
评论
0赞
user1168149
11/7/2023
非常感谢你,我很抱歉发布了一个混乱的代码。
评论