def convert_state(self, state):
    """Map a raw state into the feature space selected by ``state_converting_basis``.

    Supported bases:

    * ``"polynomial"`` -- polynomial features of degree ``self.degree``, with a
      bias column when ``self.include_bias`` is truthy. When ``degree == 1``
      and ``self.intercept == 0`` the state is returned unchanged.
    * ``"rbf"`` -- random Fourier features from ``RBFSampler`` with
      ``gamma=1.0`` and ``self.n_components`` components. The fitted sampler
      is cached on the instance (``self._rbf_sampler``) so a given state maps
      to the same features on every call.
    * ``"one_hot"`` -- one-hot encoding against the categories in
      ``self.state_space``.
    * any other value -- the state is returned unchanged.

    Args:
        state: For ``"polynomial"``/``"rbf"``, presumably a 2-D array-like of
            shape ``(n_samples, n_features)`` (the scikit-learn transformers
            used here reject 1-D input). For ``"one_hot"``, an array exposing
            ``.reshape``; it is flattened into a single column of category
            values.

    Returns:
        A dense feature array, or ``state`` itself when no conversion applies.
    """
    basis = self.state_converting_basis

    if basis == "polynomial":
        # NOTE(review): this shortcut tests `self.intercept`, whereas the bias
        # column below and compute_transformed_state_dim use
        # `self.include_bias`; confirm the two attributes are kept in sync.
        if self.degree == 1 and self.intercept == 0:
            return state
        # bool() gives the same bias decision as the truthiness test that
        # compute_transformed_state_dim applies to include_bias.
        poly = PolynomialFeatures(degree=self.degree, include_bias=bool(self.include_bias))
        return poly.fit_transform(state)

    if basis == "rbf":
        import numpy as np

        shape = np.shape(state)
        n_features = shape[-1] if shape else None
        sampler = getattr(self, "_rbf_sampler", None)
        # RBFSampler draws its random projection weights in fit(). Refitting on
        # every call would give the same state different features each time,
        # so refit only when the cached sampler's weight matrix no longer
        # matches the input width / configured number of components.
        if sampler is None or sampler.random_weights_.shape != (n_features, self.n_components):
            sampler = RBFSampler(gamma=1.0, n_components=self.n_components)
            sampler.fit(state)
            self._rbf_sampler = sampler
        return sampler.transform(state)

    if basis == "one_hot":
        try:
            # scikit-learn >= 1.2 names the dense-output flag `sparse_output`;
            # the old `sparse` keyword was removed in 1.4.
            encoder = OneHotEncoder(sparse_output=False, categories=[self.state_space])
        except TypeError:
            # Older scikit-learn only understands `sparse`.
            encoder = OneHotEncoder(sparse=False, categories=[self.state_space])
        return encoder.fit_transform(state.reshape(-1, 1))

    return state
def compute_transformed_state_dim(self):
    """Return the feature dimension expected for the configured state basis.

    * ``"rbf"`` -- ``self.n_components``.
    * ``"one_hot"`` -- the number of categories in ``self.state_space``.
    * ``"polynomial"`` -- the number of monomials of total degree at most
      ``self.degree`` over the base dimension, ``C(base + degree, degree)``,
      minus one when ``self.include_bias`` is falsy (no constant term).
    * any other value -- the base dimension itself.

    The base dimension is ``int(self.state_dim)``, or ``int(self.state_dim - 1)``
    when ``self.team_indicator`` is truthy (presumably one state entry is a
    team flag that is not counted -- confirm against the caller).
    """
    basis = self.state_converting_basis
    if basis == "rbf":
        return self.n_components
    if basis == "one_hot":
        return len(self.state_space)

    # Only the remaining bases depend on state_dim / team_indicator, so they
    # are read lazily here rather than up front.
    if self.team_indicator:
        base_dim = int(self.state_dim - 1)
    else:
        base_dim = int(self.state_dim)

    if basis != "polynomial":
        return base_dim

    monomial_count = math.comb(base_dim + self.degree, self.degree)
    if self.include_bias:
        return monomial_count
    return monomial_count - 1