嵌套类方法的 JAX @jit

JAX @jit for nested class method

提问人:user1168149 提问时间:11/7/2023 更新时间:11/7/2023 访问量:40

问:

我正在尝试使用嵌套函数,但遇到了问题。 我有一个类,它用一个方法接受另一个类。 我想将此方法称为 jitted from . 我想我遵循了 JAX 的常见问题解答,“如何将 jit 与方法一起使用?https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods但是,我遇到了一个错误,说. 谁能告诉我如何解决这个问题?@jitOnePlantfuncfuncOneTypeError: One.__init__() got multiple values for argument 'plant'

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util

class One:
    def __init__(self, plant,x):
        self.plant = plant
        self.x = x
    
    @jit
    def call_plant_func(self,y):
        out = self.plant.func(y) + self.x
        return out
    
    def _tree_flatten(self):
        children = (self.x,)  # arrays / dynamic values
        aux_data = {'plant':self.plant}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        import pdb; pdb.set_trace();
        return cls(*children, **aux_data)
        
tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)    
    
class Plant:
    def __init__(self, z,kk):
        self.z =z
    
    @jit
    def func(self,y):
        y = y + self.z
        return y
    
    def _tree_flatten(self):
        children = (self.z,)  # arrays / dynamic values
        aux_data = None # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        return cls(*children)
   
tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))

最后一行给了我一个上面描述的错误。

嵌套的 JIT JAX

评论


答:

0赞 jakevdp 11/7/2023 #1

您在两个类的 and 代码中都有问题。tree_flattentree_unflatten

  • One._tree_flatten被视为静态数据,但事实并非如此:它是一个具有非静态元素的 pytree。plant
  • One._tree_unflatten实例化参数的顺序错误,导致您看到的错误One
  • Plant.__init__对参数不执行任何操作。kk
  • Plant._tree_unflatten缺少参数,并且无法将参数传递给aux_datakkPlant.__init__

修复这些问题后,代码将执行而不会出错:

class One:
    def __init__(self, plant,x):
        self.plant = plant
        self.x = x
    
    @jit
    def call_plant_func(self,y):
        out = self.plant.func(y) + self.x
        return out
    
    def _tree_flatten(self):
        children = (self.plant, self.x)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children)
        
tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)    
    
class Plant:
    def __init__(self, z, kk):
        self.kk = kk
        self.z =z
    
    @jit
    def func(self, y):
        y = y + self.z
        return y
    
    def _tree_flatten(self):
        children = (self.z, self.kk)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children)
   
tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))

评论

0赞 user1168149 11/7/2023
非常感谢你,我很抱歉发布了一个混乱的代码。