# -*- coding: utf-8 -*-
"""
Simple demonstration of the Ellipsoid online learner on MNIST.

The data are the "4 vs 7" training files (see ``load_data``).

Ways to run it:
- Run the file as a script.
- Run it cell by cell in an IDE that understands ``#%%`` markers. The last
  cell does the actual work.

Set the ``MNIST_DIR`` environment variable to read the data from a directory
other than the default one.

@author: shais
"""

import os

import numpy as np

# Short numpy aliases, kept for interactive (console) use.
dot, sign, outer = np.dot, np.sign, np.outer
zeros, eye, sqrt = np.zeros, np.eye, np.sqrt

# Directory holding the data files; overridable without editing the source.
dataDir = os.environ.get("MNIST_DIR", "/Users/shais/data/mnist/mnist/")

# Every column of X is displayed as one flattened 28x28 image.
IMG_SHAPE = (28, 28)


#%%
def load_data(data_dir=None):
    """Load the 4-vs-7 training data and labels.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``train4vs7_data.txt.gz`` and
        ``train4vs7_labels.txt.gz``. Defaults to the module-level ``dataDir``,
        which is read at call time.

    Returns
    -------
    X : ndarray
        Data matrix. The rest of this file treats it as shape (d, n), with one
        example per column.
    Y : ndarray
        Labels, one per column of X.
    """
    if data_dir is None:
        data_dir = dataDir
    # os.path.join works whether or not data_dir ends with a separator.
    X = np.loadtxt(os.path.join(data_dir, "train4vs7_data.txt.gz"))
    Y = np.loadtxt(os.path.join(data_dir, "train4vs7_labels.txt.gz"))
    return X, Y


#%%
def show_examples(X, Y, rows=5, cols=5, fig=1):
    """Show the first ``rows * cols`` examples of X as a grid of images.

    Examples with label > 0 are drawn as-is. The others are drawn inverted
    (``255 - x``) so the two classes are easy to tell apart. The ``255``
    presumably assumes pixel values in [0, 255]; confirm against the data.
    """
    # Imported lazily so the learner can be used without a plotting backend.
    import matplotlib.pyplot as plt

    plt.figure(fig)
    count = min(rows * cols, X.shape[1])
    for k in range(count):
        # Subplot positions are 1-based, data columns are 0-based.
        ax = plt.subplot(rows, cols, k + 1)
        ax.axis('off')
        img = X[:, k].reshape(IMG_SHAPE)
        ax.imshow(img if Y[k] > 0 else 255 - img, cmap="gray")
    plt.draw()


#%%
def ellipsoid_learner(X, Y):
    """Run one online pass of the Ellipsoid algorithm over the columns of X.

    Starting from the unit ball (``w = 0``, ``A = I``), each example is
    predicted with ``sign(w . x)``. On a mistake, the center ``w`` is moved
    along ``A x`` towards the correct side. The shape matrix ``A`` is then
    shrunk along that direction and rescaled by ``d^2 / (d^2 - 1)``.

    Parameters
    ----------
    X : array_like, shape (d, n)
        Examples, one per column.
    Y : array_like, shape (n,)
        Labels. They are compared directly against ``sign(w . x)``, so they
        are expected to be +1 / -1.

    Returns
    -------
    w : ndarray, shape (d,)
        Final ellipsoid center, used as the linear predictor.
    A : ndarray, shape (d, d)
        Final ellipsoid shape matrix.
    M : int
        Number of mistakes made during the pass.

    Raises
    ------
    ValueError
        If X is not 2-D, if Y does not hold one label per column of X, or if
        ``d < 2`` (the ``d^2 / (d^2 - 1)`` factor is undefined for d = 1).
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be a 2-D array with one example per column")
    d, n = X.shape
    if Y.shape != (n,):
        raise ValueError("Y must hold exactly one label per column of X")
    if d < 2:
        raise ValueError("the Ellipsoid update needs dimension d >= 2")

    w = np.zeros(d)
    A = np.eye(d)
    M = 0  # counts mistakes
    eta = d * d / (d * d - 1.0)
    for t in range(n):
        x, y = X[:, t], Y[t]
        # A zero margin gives sign 0 and therefore counts as a mistake. This
        # includes the first example, since w starts at 0.
        if np.sign(np.dot(w, x)) == y:
            continue
        M += 1
        Ax = A.dot(x)
        xAx = x.dot(Ax)
        if xAx <= 0.0:
            # x is all zeros, or A has numerically lost positive-definiteness.
            # No valid cut exists; updating would divide by zero and turn
            # w and A into inf/NaN for the rest of the pass.
            continue
        w = w + y / ((d + 1) * np.sqrt(xAx)) * Ax
        A = eta * (A - (2.0 / ((d + 1.0) * xAx)) * np.outer(Ax, Ax))
    return w, A, M


#%%
def squash_weights(w, gain=10.0):
    """Map weights into (0, 1) with a logistic curve, for display.

    Computes ``1 / (1 + exp(-gain * w / scale))``, where ``scale`` is
    ``w.max()``. If ``w.max() <= 0``, ``max(|w|)`` is used instead: dividing
    by zero would give inf/NaN, and dividing by a negative number would flip
    the contrast. An all-zero ``w`` maps to 0.5 everywhere.
    """
    w = np.asarray(w, dtype=float)
    if w.size == 0:
        return w.copy()
    scale = w.max()
    if scale <= 0:
        scale = np.abs(w).max()
    if scale == 0:
        return np.full_like(w, 0.5)
    return 1.0 / (1.0 + np.exp(-gain * w / scale))


def show_mask(w, fig=2):
    """Show the learnt weights as an image, raw (left) and squashed (right)."""
    import matplotlib.pyplot as plt

    plt.figure(fig)
    ax1 = plt.subplot(1, 2, 1)
    ax1.axis('off')  # no need for axis marks
    ax2 = plt.subplot(1, 2, 2)
    ax2.axis('off')  # no need for axis marks
    ax1.imshow(np.reshape(w, IMG_SHAPE), cmap="gray")
    ax2.imshow(squash_weights(w).reshape(IMG_SHAPE), cmap="gray")
    plt.draw()


#%%
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Kept at module level so X, Y, d, n, w, A, M remain available in the
    # console after a run, as they were in the original cell-based script.
    X, Y = load_data()
    d, n = X.shape
    show_examples(X, Y)
    w, A, M = ellipsoid_learner(X, Y)
    print("Ellipsoid made {} mistakes on {} examples.".format(M, n))
    show_mask(w)
    # plt.draw() alone never opens a window when run as a plain script.
    plt.show()