# assumed setup; these imports/definitions may already exist earlier in the notebook
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def toy_problem(x, proba_flip=0.05):
    # XOR of the two coordinates' signs; each label is flipped with probability proba_flip
    out = (x[:, 0] > 0) ^ (x[:, 1] > 0)
    # out = (x[:,1]**2 - 2*x[:,0] > 0) ^ (x[:,1]**3 - x[:,0]**2 > 0)
    flip = (torch.rand(x.shape[0]) < proba_flip).to(x.device)
    return out ^ flip
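
# nn_freeze and optim_lr are used below but defined elsewhere in the notebook;
# minimal sketches of what they plausibly do (assumptions, guarded so they
# don't shadow the real definitions if those cells were already run):
if 'nn_freeze' not in globals():
    def nn_freeze(module):
        # freeze a module by disabling gradients on all of its parameters
        for p in module.parameters():
            p.requires_grad = False
        return module
if 'optim_lr' not in globals():
    def optim_lr(optimizer, lr):
        # set the learning rate on every param group of an existing optimizer
        for g in optimizer.param_groups:
            g['lr'] = lr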
toy_set_x = torch.randn((1000,2)).to(device)
toy_set_y = toy_problem(toy_set_x).to(device)
def toy_viz_surf(model):
    # render the model's 2-output response over [-2,2]^2 as an RGB image
    res = 2**8
    x = torch.linspace(-2, 2, res)
    y = torch.linspace(-2, 2, res)
    # 'xy' indexing + origin='lower' below so the image axes match the scatter plots
    grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
    coords = torch.stack([grid_x, grid_y])
    coords = coords.permute(1, 2, 0).view(-1, 2).to(device)
    # frozen linear map sending the 2 class scores to RGB (class 0 -> red, class 1 -> blue)
    L = nn_freeze(nn.Linear(2, 3).to(device))
    L.weight.data = torch.tensor([[.5, 0], [0, 0], [0, .5]]).to(device)
    L.bias.data = torch.zeros(3).to(device)
    pixels = L(model(coords)).view(res, res, 3)
    plt.imshow(pixels.cpu().detach(), extent=(-2, 2, -2, 2), origin='lower')
def toy_viz(model, toy_set_x, toy_set_y):
    toy_nn_y = torch.argmax(model(toy_set_x), dim=1)
    toy_nn_y_ok = (toy_nn_y == 1*toy_set_y)  # True where the model agrees with the (noisy) labels
    plt.figure(figsize=(20, 6))
    plt.subplot(1, 3, 1)
    plt.scatter(toy_set_x[ toy_set_y][:, 0].cpu(), toy_set_x[ toy_set_y][:, 1].cpu(), c='red', s=2)
    plt.scatter(toy_set_x[~toy_set_y][:, 0].cpu(), toy_set_x[~toy_set_y][:, 1].cpu(), c='blue', s=2)
    toy_viz_surf(lambda x: 1.*F.one_hot(1*toy_problem(x, proba_flip=0), num_classes=2))
    plt.title("ground truth + noisy samples")
    plt.subplot(1, 3, 2)
    plt.scatter(toy_set_x[toy_nn_y == 1][:, 0].cpu(), toy_set_x[toy_nn_y == 1][:, 1].cpu(), c='red', s=2)
    plt.scatter(toy_set_x[toy_nn_y == 0][:, 0].cpu(), toy_set_x[toy_nn_y == 0][:, 1].cpu(), c='blue', s=2)
    plt.title("model's perception + relabeled samples")
    toy_viz_surf(model)
    plt.subplot(1, 3, 3)
    # plt.scatter(toy_set_x[:,0].cpu(), toy_set_x[:,1].cpu(), c='white', s=2)
    plt.scatter(toy_set_x[~toy_nn_y_ok][:, 0].cpu(), toy_set_x[~toy_nn_y_ok][:, 1].cpu(), c='green', s=2)
    toy_viz_surf(lambda x: 1. - torch.abs(1.*F.one_hot(1*toy_problem(x, proba_flip=0), num_classes=2) - model(x)))
    plt.title("difference between model and ground truth + label mismatch")
def toy_train(model, epochs, lr=1e-4):
    # reuse one Adam optimizer attached to the model across calls
    if not hasattr(model, 'my_optimizer'):
        model.my_optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    toy_opti = model.my_optimizer
    optim_lr(toy_opti, lr)
    plt.figure(figsize=(20, 5))
    for epoch in range(epochs):  # with the fixed dataset below, each "epoch" is one full-batch step
        # shape = [batch_size,2]
        # x = torch.randn(shape).to(device)  # alternative: sample a fresh batch each step
        x = toy_set_x
        model.train()
        toy_opti.zero_grad()
        # note: targets here are the clean labels (proba_flip=0); use toy_set_y to train on the noisy ones
        loss = F.mse_loss(model(x), 1.*F.one_hot(1*toy_problem(x, proba_flip=0), num_classes=2))
        loss.backward()
        toy_opti.step()
        # print(loss.item())
        plt.scatter(epoch, np.log(loss.item()), s=2)
    # quick evaluation on fresh samples
    N = 1000
    toy_set_test_x = torch.randn((N, 2)).to(device)
    toy_set_test_y = toy_problem(toy_set_test_x).to(device)  # test labels keep the default label noise
    toy_nn_test_y = torch.argmax(model(toy_set_test_x), dim=1)
    correct_count = torch.sum(toy_nn_test_y == 1*toy_set_test_y).item()
    toy_viz(model, toy_set_test_x, toy_set_test_y)
    print(f"correct rate : {correct_count/N}")
toy_nn = nn.Sequential(*[
    nn.Linear(2, 16), nn.ReLU(),
    # nn.Linear(16,16), nn.ReLU(),
    # nn.Linear(16,16), nn.ReLU(),
    # nn.Linear(16,16), nn.ReLU(),
    # nn.Linear(16,16), nn.ReLU(),
    nn.Linear(16, 2), nn.Softmax(dim=1),
]).to(device)
toy_train(toy_nn,2**12,lr=1e-2)
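# optional: a quick look at the learned decision surface on its own
# (toy_train already calls toy_viz; this just demos toy_viz_surf directly)
plt.figure(figsize=(6, 6))
toy_viz_surf(toy_nn)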
from torchvision import datasets
from torchvision.transforms import ToTensor
storage_path = '/content/drive/My Drive/torch-dataset/'
mnist_set_train = datasets.MNIST(root=storage_path, download=True, train=True , transform=ToTensor())
mnist_set_test = datasets.MNIST(root=storage_path, download=True, train=False, transform=ToTensor())
# mnist_set_train = datasets.FashionMNIST(root=storage_path, download=True, train=True , transform=ToTensor())
# mnist_set_test = datasets.FashionMNIST(root=storage_path, download=True, train=False, transform=ToTensor())
# note: indexing .data/.targets directly bypasses the ToTensor transform, so normalize by hand
set_train   = mnist_set_train.data.to(device).float().view(-1, 1, 28, 28)/255.
set_train_y = F.one_hot(mnist_set_train.targets.to(device), num_classes=10).float()
set_test    = mnist_set_test.data.to(device).float().view(-1, 1, 28, 28)/255.
set_test_y  = F.one_hot(mnist_set_test.targets.to(device), num_classes=10).float()
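# quick sanity check of the loaded data (illustrative; index 0 is arbitrary)
plt.imshow(set_train[0, 0].cpu(), cmap='gray')
print("label:", torch.argmax(set_train_y[0]).item())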
def mnist_train(model, epochs, batch_size=1000, lr=1e-4):
    model.train()
    # reuse one Adam optimizer attached to the model across calls
    if not hasattr(model, 'my_optimizer'):
        model.my_optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model_opti = model.my_optimizer
    optim_lr(model_opti, lr)
    plt.figure(figsize=(20, 10))
    for epoch in range(epochs):
        losses = []
        for i in range(0, len(set_train), batch_size):
            # slicing clamps at the end of the tensor, so no min() is needed
            x = set_train  [i:i+batch_size]
            y = set_train_y[i:i+batch_size]
            model_opti.zero_grad()
            loss = F.mse_loss(model(x), y)
            loss.backward()
            model_opti.step()
            losses.append(loss.item())
        plt.scatter(epoch, np.log(np.mean(losses)), s=2)  # yeah, the mean over batches isn't perfect, but it's mostly fine
    plt.xscale('log')
    mnist_perf(model)
def mnist_perf(model):
    # accuracy on both splits; no_grad since this is inference only
    with torch.no_grad():
        y    = torch.argmax(set_test_y, dim=1)
        y_nn = torch.argmax(model(set_test), dim=1)
        correct_rate = torch.sum(y == y_nn).item() / len(set_test)
        print(f"correct rate on test set = {correct_rate}")
        y    = torch.argmax(set_train_y, dim=1)
        y_nn = torch.argmax(model(set_train), dim=1)
        correct_rate = torch.sum(y == y_nn).item() / len(set_train)
        print(f"correct rate on train set = {correct_rate}")
model = nn.Sequential(*[
    nn.Conv2d(1, 16, (3, 3)), nn.ReLU(), nn.MaxPool2d((2, 2)),   # 28x28 -> 26x26 -> 13x13
    nn.Conv2d(16, 32, (3, 3)), nn.ReLU(), nn.MaxPool2d((2, 2)),  # 13x13 -> 11x11 -> 5x5
    nn.Conv2d(32, 64, (3, 3)), nn.ReLU(), nn.MaxPool2d((2, 2)),  # 5x5 -> 3x3 -> 1x1
    nn.Flatten(), nn.Linear(64, 10), nn.Softmax(dim=1),          # 64 features -> 10 class scores
]).to(device)
mnist_train(model,200)
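# quick inference check on a single test digit (illustrative; index 0 is arbitrary)
with torch.no_grad():
    pred = model(set_test[:1]).argmax(dim=1).item()
print(f"predicted {pred}, true {torch.argmax(set_test_y[0]).item()}")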