【python】使用基函数（basis function）转换数据，并计算转换之后的维数

 def convert_state(self, state):
        """Map ``state`` through the basis named by ``self.state_converting_basis``.

        Supported bases:

        * ``"polynomial"`` -- ``PolynomialFeatures`` of degree ``self.degree``;
          a bias column is included when ``self.include_bias`` is non-zero.
          For degree 1 without a bias column the expansion is the identity,
          so ``state`` is returned as-is without touching scikit-learn.
        * ``"rbf"`` -- ``RBFSampler`` with ``gamma=1.0`` and
          ``self.n_components`` output features.
        * ``"one_hot"`` -- ``OneHotEncoder`` over the categories listed in
          ``self.state_space``; ``state`` is reshaped to a single column
          first, so it must provide ``reshape``.
        * anything else -- ``state`` is returned unchanged.

        Args:
            state: The raw state(s) to transform. Presumably a 2-D array of
                shape (n_samples, n_features) for the polynomial/rbf bases,
                as scikit-learn transformers expect -- TODO confirm with caller.

        Returns:
            The transformed features, or ``state`` itself when no transform
            applies.
        """
        basis = self.state_converting_basis

        if basis == "polynomial":
            # Use the same flag that is passed to PolynomialFeatures below (and
            # that compute_transformed_state_dim reads), so the shortcut and the
            # reported dimension can never disagree.
            if self.degree == 1 and self.include_bias == 0:
                return state
            poly = PolynomialFeatures(
                degree=self.degree, include_bias=(self.include_bias != 0)
            )
            return poly.fit_transform(state)

        if basis == "rbf":
            # NOTE(review): a new sampler is fitted on every call and no
            # random_state is given, so the random projection presumably
            # differs between calls -- confirm this is intended.
            rbf = RBFSampler(gamma=1.0, n_components=self.n_components)
            return rbf.fit_transform(state)

        if basis == "one_hot":
            # scikit-learn renamed ``sparse`` to ``sparse_output`` (1.2) and
            # later removed the old keyword; try the new name first and fall
            # back for older releases that reject it.
            try:
                encoder = OneHotEncoder(
                    sparse_output=False, categories=[self.state_space]
                )
            except TypeError:
                encoder = OneHotEncoder(
                    sparse=False, categories=[self.state_space]
                )
            return encoder.fit_transform(state.reshape(-1, 1))

        # Unknown basis: pass the state through untouched.
        return state
        

    
    def compute_transformed_state_dim(self):
        """Return the feature count produced by the configured basis.

        * ``"rbf"``: ``self.n_components``.
        * ``"one_hot"``: the number of entries in ``self.state_space``.
        * ``"polynomial"``: the number of monomials of total degree at most
          ``self.degree`` in the raw state variables, i.e.
          ``C(dim + degree, degree)``, minus one when ``self.include_bias``
          is falsy (the constant term is then not counted).
        * any other basis: the raw state dimension.

        The raw state dimension is ``self.state_dim``, reduced by one when
        ``self.team_indicator`` is truthy.
        """
        basis = self.state_converting_basis

        # Bases whose output width does not depend on the raw state dimension.
        if basis == "rbf":
            return self.n_components
        if basis == "one_hot":
            return len(self.state_space)

        if self.team_indicator:
            raw_dim = int(self.state_dim - 1)
        else:
            raw_dim = int(self.state_dim)

        if basis != "polynomial":
            return raw_dim  # default

        monomial_count = math.comb(raw_dim + self.degree, self.degree)
        if self.include_bias:
            return monomial_count
        return monomial_count - 1

你可能感兴趣的:(python学习,python,开发语言)