
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter

def tensorDataGenRe(num_features=1000, w=(2, 1, -1), bias=True, belta=0.1, bag=1):
    """Generate a synthetic regression dataset with additive Gaussian noise.

    Raw features are sampled from a standard normal distribution, raised
    element-wise to the power ``bag``, and combined linearly with the
    coefficients in ``w``. Noise drawn from a standard normal distribution
    and scaled by ``belta`` is then added to the labels.

    :param num_features: number of samples (rows) to generate
    :param w: coefficients; when ``bias`` is true the last entry is used as
        the intercept and the remaining entries as feature weights
    :param bias: whether ``w`` contains an intercept as its last entry
    :param belta: scale of the noise term added to the labels
    :param bag: exponent applied element-wise to the features
    :return: ``(features, labels)``. ``features`` has ``len(w)`` columns (when
        ``bias`` is true the last column is all ones); ``labels`` has shape
        ``(num_features, 1)``
    """
    if bias:
        input_num = len(w) - 1
        weights = torch.tensor(w[:-1], dtype=torch.float32).reshape(input_num, 1)
        intercept = torch.tensor(w[-1]).reshape(-1, 1).float()
    else:
        input_num = len(w)
        weights = torch.tensor(w, dtype=torch.float32).reshape(input_num, 1)
        intercept = None

    raw_features = torch.randn(num_features, input_num, dtype=torch.float32)

    # A matrix product also covers the single-feature case: (N, 1) @ (1, 1)
    # gives the same result as the element-wise product.
    clean_labels = torch.mm(torch.pow(raw_features, bag), weights)

    if bias:
        clean_labels = clean_labels + intercept
        # Append a column of ones so the intercept can be learned as a weight.
        features = torch.cat((raw_features, torch.ones(len(raw_features), 1)), 1)
    else:
        features = raw_features

    # The noise must match the labels' shape (not take the tensor as a size).
    labels = clean_labels + torch.randn_like(clean_labels) * belta
    return features, labels

def tensorDataGenCla(feature_count=500, feature_class=2, class_count=3, big_size=(4, 2), bais=False):
    """Generate a synthetic classification dataset of Gaussian clusters.

    One cluster is sampled per class. Every feature of class ``i`` is drawn
    from a normal distribution with mean ``i * big_size[0]`` shifted so that
    the class means are centred around zero, and standard deviation
    ``big_size[1]``.

    :param feature_count: number of samples generated for each class
    :param feature_class: number of feature columns
    :param class_count: number of classes (labels are ``0 .. class_count - 1``)
    :param big_size: ``[mean_step, std]`` - spacing between consecutive class
        means and the standard deviation shared by all classes
    :param bais: when true, a column of ones is appended to the features
    :return: ``(feature, label)`` as float tensors; ``feature`` has
        ``feature_count * class_count`` rows and ``label`` has shape
        ``(feature_count * class_count, 1)``
    :raises ValueError: if ``class_count`` is smaller than 1
    """
    if class_count < 1:
        raise ValueError("class_count must be at least 1")

    mean_step, std_ = big_size[0], big_size[1]
    # Shift that centres the class means symmetrically around zero.
    centre_shift = mean_step * (class_count - 1) / 2

    feature_parts = []
    label_parts = []
    for i in range(class_count):
        feature_parts.append(
            torch.normal(i * mean_step - centre_shift, std_, size=(feature_count, feature_class))
        )
        label_parts.append(torch.full((feature_count, 1), i, dtype=torch.float32))

    feature = torch.cat(feature_parts).float()
    label = torch.cat(label_parts).float()
    if bais:
        feature = torch.cat((feature, torch.ones(len(feature), 1)), 1)
    return feature, label

def data_iter(features, bach_size, labels):
    """Split features and labels into consecutive mini-batches.

    Batches are taken in the original row order (no shuffling is performed);
    the last batch is smaller when the number of rows is not a multiple of
    ``bach_size``.

    :param features: tensor of features, batched along dimension 0
    :param bach_size: maximum number of rows in each batch
    :param labels: tensor of labels, indexed with the same rows as ``features``
    :return: list of ``[feature_batch, label_batch]`` pairs
    :raises ValueError: if ``bach_size`` is not positive
    """
    if bach_size <= 0:
        raise ValueError("bach_size must be a positive integer")

    num_features = len(features)
    batches = []
    for start in range(0, num_features, bach_size):
        rows = torch.arange(start, min(start + bach_size, num_features))
        batches.append(
            [torch.index_select(features, 0, rows), torch.index_select(labels, 0, rows)]
        )
    return batches

if __name__ == '__main__':
    # Example usage, left disabled so that running the module directly does
    # nothing:
    # a, b = tensorDataGenCla(bais=True)
    # print(a)
    # print(b)
    pass  # an `if` needs at least one statement; comments alone do not parse
