/* USING SAS TO FIND THE BEST K FOR K-NEAREST-NEIGHBOR CLASSIFICATION */

******(1) USE PROC MODECLUS (K = 4) TO FIND NEAREST NEIGHBORS**************;
/* Cluster the four iris measurements with PROC MODECLUS (K = 4).
   OUT= (_TEST1) carries the per-observation DENSITY and CLUSTER used
   in the first plot; the NEIGHBOR table is captured in _TEST2. */
proc modeclus data = sashelp.iris m = 1 k = 4 out = _test1 neighbor;
   var petallength petalwidth sepallength sepalwidth;
   ods output neighbor = _test2;
run;
ods html style = harvest image_dpi = 400;
/* Density of every observation by species, labelled with its cluster. */
proc sgplot data=_test1;
   scatter y = density x = species / datalabel = cluster;
run;
/* In the NEIGHBOR table ID is only filled on the first row of each
   observation block; carry the last non-missing ID down to the
   following rows so every neighbor row knows its owner. */
data _test3;
   set _test2; retain _tmpid;
   if missing(id) = 0 then _tmpid = id; else id = _tmpid;
run;
/* Number the neighbors 1, 2, ... within each ID.
   _TEST4: every neighbor row with its sequence number.
   _TEST5: last row per ID, so NEIGHBOR = neighbor count for that ID. */
data _test4 _test5;
   set _test3; by id notsorted;
   if first.id then neighbor = 0;
   neighbor + 1; output _test4;
   if last.id then output _test5;
run;
ods graphics / width = 6in height = 1in ;
/* Distance to each neighbor, one bar per observation, grouped by
   neighbor sequence number. */
proc sgplot data = _test4;
   vbar id / response = distance group = neighbor;
   xaxis display = none grid;
run;
/* How many neighbors were found for each observation. */
proc sgplot data = _test5;
   series x = id y = neighbor;
   xaxis display = none grid;
   yaxis values = (1 to 6) label = 'No. of neighbors';
run;

******(2) PARTITION RAW DATASET INTO TRAINING AND VALIDATION DATASETS******;
%macro partition(data = , target = , smpratio = ,
                  seed = , train = , validate = );
/**************************************************************
* MACRO: partition()
* GOAL: split DATA into a training and a validation set that
*       keep the proportions of the TARGET levels (the sample
*       is stratified by TARGET with proportional allocation)
* PARAMETERS: data     = raw dataset
*             target   = target (stratification) variable
*             smpratio = sampling rate given to PROC SURVEYSELECT;
*                        the sampled rows go to VALIDATE and the
*                        remaining rows go to TRAIN
*             seed     = random seed for sampling
*             train    = output training dataset
*             validate = output validation dataset
* NOTES: only the numeric variables of DATA plus TARGET are kept.
*        Every dataset in the default library whose name starts
*        with an underscore is deleted at the end.
**************************************************************/
%local num_var;
/* Start empty so the KEEP list never references an unresolved
   macro variable when DATA has no numeric columns. */
%let num_var = ;

ods select none;
ods output variables = _varlist;
proc contents data = &data;
run;

/* Blank-separated list of the numeric variables of DATA. */
proc sql noprint;
   select variable into: num_var separated by ' '
   from _varlist
   where lowcase(type) = 'num';
quit;

/* The STRATA statement of SURVEYSELECT needs input sorted by the
   strata variable. */
proc sort data = &data out = _tmp1;
   by &target;
run;

/* OUTALL keeps every row and flags the sampled ones (Selected = 1). */
proc surveyselect data = _tmp1 samprate = &smpratio
   out = _tmp2 seed = &seed outall;
   strata &target / alloc = prop;
run;

data &train &validate;
   set _tmp2;
   keep &num_var &target;
   if selected = 0 then output &train;
   else output &validate;
run;

proc datasets nolist;
   delete _:;
quit;
ods select all;
%mend;

%partition(data = sashelp.iris, target = species, smpratio = 0.5, 
       seed = 20110901, train = iris_train, validate = iris_validate);
%partition(data = sashelp.cars, target = origin, smpratio = 0.5,
       seed = 20110901, train = cars_train, validate = cars_validate);

********(3) BUILD A USER-DEFINED FUNCTION TO IMPLEMENT K-NN************************;
option mstored sasmstore = sasuser;
%macro knn_macro / store source;
/***********************************************************
* MACRO: knn_macro (stored compiled macro, no parameters)
* GOAL: fit a nonparametric k-NN discriminant model on the
*       training data, score the validation data and store
*       the overall misclassification rate in macro variable
*       MISC
* INPUT: macro variables K, TARGET, INPUT, TRAIN and VALIDATE
*        must exist in the calling scope (knn() creates them
*        through RUN_MACRO); surrounding quotes are stripped
*        with DEQUOTE
* OUTPUT: macro variable MISC, dataset _SCORED (TESTOUT=)
*         and dataset _CLASSIFIEDTESTCLASS (ODS output)
* On invalid arguments an ERROR line is written to the log
* and the macro exits before PROC DISCRIM runs.
***********************************************************/
%let target = %sysfunc(dequote(&target));
%let input = %sysfunc(dequote(&input));
%let train = %sysfunc(dequote(&train));
%let validate = %sysfunc(dequote(&validate));
%let error = 0;
%if %length(&k) = 0 %then %do;
%put ERROR: Value for K is missing ;
%let error = 1;
%end;
%else %if %eval(&k) le 0 or %sysfunc(anydigit(&k)) = 0 %then %do;
%put ERROR: Value for K is invalid ;
%let error = 1;
%end;
%if %length(&target) = 0 %then %do;
%put ERROR: Value for target is missing ;
%let error = 1;
%end;
%if %length(&input) = 0 %then %do;
%put ERROR: Value for INPUT is missing ;
%let error = 1;
%end;
%if %sysfunc(exist(&train)) = 0 %then %do; %put ERROR: Training dataset does not exist ;
%let error = 1;
%end;
%if %sysfunc(exist(&validate)) = 0 %then %do;
%put ERROR: validation dataset does not exist ;
%let error = 1;
%end;
%if &error = 1 %then %goto finish;

ods output classifiedtestclass = _classifiedtestclass;
proc discrim data = &train  test = &validate  testout = _scored
   method = npar k = &k testlist ;
   class &target;
   var &input;
run;

/* Misclassification rate = share of scored validation rows whose
   predicted class (_INTO_) differs from the true class. */
data _null_;
   set _scored nobs = nobs end = eof;
   /* Explicit zero start: a perfect fit must give 0, not missing. */
   retain count 0;
   if &target ne _into_ then count + 1;
   if eof then do;
      misc = count / nobs;
      /* SYMPUTX avoids the numeric-to-character conversion note
         and the leading blanks of SYMPUT. */
      call symputx('misc', misc);
   end;
run;
%finish:;
%mend;

proc fcmp outlib = sasuser.knn.funcs;
/***********************************************************
* FUNCTION: knn()
* GOAL: apply k-Nearest-Neighbor for classification by
*       running the stored macro knn_macro via RUN_MACRO
* INPUT: k        = number of nearest neighbours
*        train    = training dataset
*        validate = validation dataset
*        target   = target variable
*        input    = input variables
* OUTPUT: overall misclassification rate (the value that
*         knn_macro leaves in MISC); missing when RUN_MACRO
*         returns a nonzero code
***********************************************************/
   function knn(k, train $, validate $, target $, input $);
      /* Start from missing so that, if knn_macro exits early
         without setting MISC, the value handed back is
         presumably the missing value passed in. */
      misc = .;
      rc = run_macro('knn_macro', k, train, validate, target, input, misc);
      if rc eq 0 then return(misc);
      else return(.);
   endsub;
run;

******(3) APPLY K-NN FUNCTION TO CLASSIFY IRIS AND CARS DATA****************;
%macro errorchk(train = , validate = , target = , input = , k = );
/***********************************************************
* MACRO: errorchk()
* GOAL: run knn() once and visualize the result as a bar
*       chart of real vs. classified levels, with the overall
*       misclassification rate shown as an inset
* PARAMETERS: train    = training dataset
*             validate = validation dataset
*             target   = target variable
*             input    = input variables
*             k        = number of nearest neighbors
* NOTE: reads _CLASSIFIEDTESTCLASS, which knn_macro writes
*       as a side effect of the knn() call below.
***********************************************************/
option cmplib = (sasuser.knn) mstored sasmstore = sasuser;

/* knn() receives the character arguments from this macro via
   SYMGET. SYMPUTX stores the rate without the numeric-to-character
   conversion note that SYMPUT emits for numeric values. */
data _null_;
   misc_rate = knn(&k, symget('train'), symget('validate'),
                   symget('target'), symget('input'));
   call symputx('misc_rate', misc_rate);
run;

/* Target levels twice: blank-separated (used as the column list of
   the classification table) and quoted lowercase, comma-separated
   (used in the WHERE filter below). */
proc sql noprint;
   select distinct &target into : varlist1 separated by ' '
      from &validate;
   select distinct cats("'", lowcase(&target), "'")
         into: varlist2 separated by ','
      from &validate;
quit;

/* One row per (real level, classified level) pair. */
proc transpose data = _classifiedtestclass out = _out1;
   by from&target notsorted;
   var &varlist1;
run;

/* Keep only rows whose real level is an actual target level
   (presumably dropping summary rows of the classification table). */
data _out2;
   set _out1;
   where lowcase(from&target) in (&varlist2);
   label _name_ = 'Level';
run;

/* The inset text is one string literal on one line, so no line break
   or indentation leaks into the displayed text. */
proc sgplot data = _out2;
   vbar from&target / response = col1 group = _name_;
   xaxis label = 'Real';
   yaxis label = 'Classified ';
   inset "Overall misclassification rate is: %sysfunc(putn(&misc_rate, percent8.2))"
         / position = topright;
run;
%mend;

ods graphics / width= 400px height = 300px ;
%errorchk(train = iris_train, validate = iris_validate, target = species,
           input = petallength petalwidth sepallength sepalwidth, k = 5);
%errorchk(train = cars_train, validate = cars_validate, target = origin,
           input = invoice wheelbase length, k = 5);

******(4) VISUALIZE CLASSIFICATION RESULT FOR CARS DATA********************;
/* _SCORED is the TESTOUT= data of the most recent knn() call, i.e. the
   cars run of errorchk above. Split INVOICE into 4 rank groups in Q
   (PROC RANK numbers groups from 0). */
proc rank data = _scored groups = 4 out = _out3;
   var invoice;
   ranks q;
run;

proc sort data = _out3 out = _out3;
   by q invoice;
run;
/* Build a CNTLIN dataset for the numeric format QVAR: one row per
   invoice group with START/END = smallest/largest invoice in it and a
   label such as "1Qu.:$start-$end" (group number shifted to start at 1). */
data _out4;
   set _out3;
   by q ;
   retain fmtname "qvar" start end;
   if first.q then start = invoice;
   if last.q then end = invoice;
   if last.q; length label $35;
      q + 1;
      label = cat(q,'Qu.:', '$',start,'-','$',end);
run;

proc format cntlin = _out4 fmtlib;
run;
/* Stack every validation car twice: once with its k-NN prediction
   (_INTO_, NAME = '2-classified') and once with its true ORIGIN
   (NAME = '1-real'). QFMT is the invoice group label from QVAR. */
data _out5(keep=level name invoice wheelbase length q qfmt);
   set _out3;
   qfmt = put(invoice, qvar.);
   level = _into_;
   name = '2-classified';
     output;
   level = origin;
   name = '1-real';
     output;
run;

ods graphics / width = 700px height = 500px ;
/* Wheelbase vs. length, one panel per invoice group x real/classified,
   colored by origin level. */
proc sgpanel data = _out5;
   panelby qfmt name / layout = lattice onepanel novarname;
   scatter x = Wheelbase y = length / group = level ;
run;

******(5) USE LOGISTIC REGRESSION TO CLASSIFY CARS DATA********************;
/* Baseline for comparison with k-NN: a generalized-logit model of ORIGIN
   on the same three inputs, fit on the training half and scored on the
   validation half. The crosstab of the SCORE output's F_ORIGIN and
   I_ORIGIN columns gives the confusion table. */
proc logistic data=cars_train;
   model origin=invoice wheelbase length / link=glogit;
   score data=cars_validate out=logitscored;
run;

proc freq data=logitscored;
   tables f_origin * i_origin / nopercent nocol nocum;
run;

******(6) RUN LOOPS ON K-NN TO FIND THE BEST K VALUE***********************;
%macro findk(train = , validate =, target = , input =, maxk =);
/***********************************************************
* MACRO: findk()
* GOAL: call knn() for every k from 1 up to MAXK and plot the
*       misclassification rate against k, with a reference
*       line at the lowest rate labelled by the best k value(s)
* PARAMETERS: train    = training dataset
*             validate = validation dataset
*             target   = target variable
*             input    = input variables
*             maxk     = largest k to try
* SIDE EFFECT: every dataset in the default library whose
*              name starts with an underscore is deleted
***********************************************************/
options cmplib = (sasuser.knn) mstored sasmstore = sasuser;

/* Suppress the procedure output produced by each knn() call. */
ods select none;

/* One row per candidate k with its validation error rate. */
data _tmp3;
   k = 1;
   do while (k le &maxk);
      misc_rate = knn(k, symget('train'), symget('validate'),
                      symget('target'), symget('input'));
      output;
      k = k + 1;
   end;
run;

/* Lowest error rate, then every k that attains it. */
proc sql;
   select min(misc_rate) into :min_misc
      from _tmp3;
   select k into :bestk separated by ', '
      from _tmp3
      where misc_rate = (select min(misc_rate) from _tmp3);
quit;

ods select all;
proc sgplot data = _tmp3;
   series x = k y = misc_rate;
   refline &min_misc / transparency = 0.3 label = "k = &bestk";
   xaxis grid values = (1 to &maxk by 1)
         label = 'k number of neareast neighbours';
   yaxis grid values = ( 0 to 0.5 by 0.05)
         label = 'Misclassification rate';
   format misc_rate percent8.1;
run;

proc datasets nolist;
   delete _:;
quit;
%mend;

ods html style = htmlblue;
%findk(train = iris_train, validate = iris_validate, target = species,
       input = petallength petalwidth sepallength sepalwidth, maxk = 20);
%findk(train = cars_train, validate = cars_validate, target = origin,
       input = invoice wheelbase length, maxk = 40);
****** END OF ALL CODING *************************************************;

/* You may also be interested in: validation, input, dataset, output, Training, classification */