模仿st_table写的StTable类

    update1:添加了remove,removeAll()方法以及getSize()方法
    update2:添加了keySet()方法用于迭代  
    update3:经过测试,StTable类在存储Integer类型key时,put的速度比HashMap快了接近3倍,而remove、get却比HashMap慢;而在存储String类型的key时,put比Hashmap慢,但是get、remove却快不少。

    读ruby hacking guide,其中专门辟了一个章节介绍了st.c中的st_table,这个数据结构也就是类似java中的HashMap,基本原理是利用数组存储,数组的每一个元素是一个单向链表,链表中再存储具体的元素,如下图所示的结构

   ruby中利用这个结构来存储对象变量、类方法、常量、全局变量等信息,因为在c ruby中,方法、变量都是用一个整型作为键值来存储在st_table中,因此这个数据结构对于以整性为键值的map类型来说速度非常不错(我没有测试内存的占用情况)。
源码如下:
// 接口,用于定义hash函数
//HashFunction.java
public   interface  HashFunction < T >  {
   
public   int  hash(T key);
}

链表元素类:
public   class  StTableEntry < T, V >  {
    
protected   int  hash;  // hash值

    
protected  T key;    //

    
protected  V value;  // 存储值

    
protected  StTableEntry < T, V >  next;  // 下一节点

    
public  StTableEntry() {

    }

    
public  StTableEntry( int  hash, T key, V value, StTableEntry < T, V >  next) {
        
super ();
        
this .hash  =  hash;
        
this .key  =  key;
        
this .value  =  value;
        
this .next  =  next;
    }

    
public   int  getHash() {
        
return  hash;
    }

    
public   void  setHash( int  hash) {
        
this .hash  =  hash;
    }

    
public  T getKey() {
        
return  key;
    }

    
public   void  setKey(T key) {
        
this .key  =  key;
    }

    
public  StTableEntry < T, V >  getNext() {
        
return  next;
    }

    
public   void  setNext(StTableEntry < T, V >  next) {
        
this .next  =  next;
    }

    
public  V getValue() {
        
return  value;
    }

    
public   void  setValue(V value) {
        
this .value  =  value;
    }

}

完整的StTable实现,没有实现remove,(update:添加了remove,removeAll()方法以及getSize()方法):
public   final   class  StTable < T, V >  {
    
private  HashFunction < T >  hashFunction;

    
private   int  num_bins;

    
int  num_entries;

    StTableEntry
< T, V > [] bins;

    
public   static   int  DEFAULT_SIZE  =   11 ;

    
private   static   int  DEFAULT_MAX_DENSITY  =   5 ;

    
private   static   int  DEFAULT_MIN_SIZE  =   8 ;

    
private   static   long  primes[]  =  {  8   +   3 16   +   3 32   +   5 64   +   3 128   +   3 ,
            
256   +   27 512   +   9 1024   +   9 2048   +   5 4096   +   3 8192   +   27 ,
            
16384   +   43 32768   +   3 65536   +   45 131072   +   29 262144   +   3 ,
            
524288   +   21 1048576   +   7 2097152   +   17 4194304   +   15 8388608   +   9 ,
            
16777216   +   43 33554432   +   35 67108864   +   15 134217728   +   29 ,
            
268435456   +   3 536870912   +   11 1073741824   +   85 0  };

    
public  StTable(HashFunction < T >  hashFunction) {
        
this .hashFunction  =  hashFunction;
        
this .num_bins  =  DEFAULT_SIZE;
        
this .num_entries  =   0 ;
        
this .bins  =   new  StTableEntry[ this .num_bins];
    }

    
public  StTable(HashFunction < T >  hashFunction,  int  size) {
        
this .hashFunction  =  hashFunction;
        
if  (size  ==   0 )
            
throw   new  IllegalArgumentException(
                    
" The size could not less than zero: "   +  size);
        
this .num_bins  =  size;
        
this .num_entries  =   0 ;
        
this .bins  =   new  StTableEntry[ this .num_bins];
    }

    
private   long  newSize( int  size) {

        
for  ( int  i  =   0 , newsize  =  DEFAULT_MIN_SIZE; i  <  primes.length; i ++ , newsize  <<=   1 ) {
            
if  (newsize  >  size)
                
return  primes[i];
        }
        
/*  Ran out of polynomials  */
        
return   - 1 /*  should raise exception  */
    }

    
public  V get(T key) {
        
int  hash_val  =  doHash(key);
        StTableEntry
< T, V >  entry  =  findEntry(hash_val, key);
        
if  (entry  ==   null )
            
return   null ;
        
else
            
return  entry.getValue();
    }

    
public  V put(T key, V value) {
        
int  hash_val  =  doHash(key);
        StTableEntry
< T, V >  entry  =  findEntry(hash_val, key);
        
if  (entry  ==   null ) {
            
//  未有键值,直接添加
            addDirect(key, value);
            
return  value;
        } 
else  {
            V v 
=  entry.value;
            entry.value 
=  value;
            
return  v;
        }
    }

    
public  V remove(T key) {
        
int  hash_val  =  doHash(key);
        
int  bin_pos  =  hash_val  %   this .num_bins;
        StTableEntry
< T, V >  entry  =   this .bins[bin_pos];
        
//  记录前一节点,考虑修改采用双向链表也可
        StTableEntry < T, V >  prev  =   null ;
        
if  (entryNotEqual(entry, key, hash_val)) {
            prev 
=  entry;
            entry 
=  entry.next;
            
while  (entryNotEqual(entry, key, hash_val)) {
                prev 
=  entry;
                entry 
=  entry.next;
            }
        }
        
if  (entry  ==   null )
            
return   null ;
        
else  {
            
if  (prev  !=   null )
                prev.next 
=  entry.next;  //  前一节点的next连接到下一节点
             else
                
this .bins[bin_pos]  =  entry.next;  //  entry恰好是第一个节点,将数组元素设置为next
            V v  =  entry.value;
           
entry  =   null //  gc友好
             return  v;
        }
       
this .num_entries =0;
    }

    
public   void  removeAll() {
        
for  ( int  i  =   0 ; i  <   this .bins.length; i ++ ) {
            StTableEntry
< T, V >  entry  =   this .bins[i];
            
this .bins[i]  =   null ;
            StTableEntry
< T, V >  temp  =  entry;
            
if  (entry  ==   null )
                
continue ;
            
while  (entry  !=   null ) {
                entry 
=   null ;
                
this .num_entries -- ;
                entry 
=  temp.next;
                temp 
=  entry;
            }
            temp 
=   null ;
            entry 
=   null ;
        }
    }

    
public   int  getSize() {
        
return   this .num_entries;
    }
   
   
public Set<T> keySet() {
        Set<T> keys = new HashSet<T>(this.num_entries);
        for (int i = 0; i < this.bins.length; i++) {
            StTableEntry<T, V> entry = this.bins[i];
            if (entry == null)
                continue;
            while (entry != null) {
                keys.add(entry.key);
                entry = entry.next;
            }

        }
        return keys;
    }
    
//  hash函数,调用hashFunction的hash方法
     private   int  doHash(T key) {
        
if  (hashFunction.hash(key)  <   0 )
            
throw   new  IllegalArgumentException(
                    
" hash value could not less than zero: "
                            
+  hashFunction.hash(key));
        
return  hashFunction.hash(key);
    }

    
//  过于拥挤,重新分布
     private   void  reHash() {
        
int  new_size  =  ( int ) newSize( this .num_bins);
        StTableEntry
< T, V > [] new_bins  =  (StTableEntry < T, V > [])  new  StTableEntry[new_size];
        
for  ( int  i  =   0 ; i  <   this .num_bins; i ++ ) {
            StTableEntry
< T, V >  entry  =   this .bins[i];
            
while  (entry  !=   null ) {
                StTableEntry
< T, V >  next  =  entry.next;
                
int  hash_val  =  entry.hash  %  new_size;
                entry.next 
=  new_bins[hash_val];
                new_bins[hash_val] 
=  entry;
                entry 
=  next;
            }
        }
        
this .bins  =   null ; //  gc友好
         this .num_bins  =  new_size;
        
this .bins  =  new_bins;

    }

    
private   void  addDirect(T key, V value) {
        
int  hash_val  =  doHash(key);
        
int  bin_pos  =  hash_val  %   this .num_bins;
        
if  (( this .num_entries  /   this .num_bins)  >  DEFAULT_MAX_DENSITY) {
            reHash();
            bin_pos 
=  hash_val  %   this .num_bins;
        }
        StTableEntry
< T, V >  entry  =   new  StTableEntry < T, V > ();
        entry.setHash(hash_val);
        entry.setKey(key);
        entry.setValue(value);
        entry.setNext(
this .bins[bin_pos]);
        
this .bins[bin_pos]  =  entry;
        
this .num_entries ++ ;
    }

    
private  StTableEntry < T, V >  findEntry( int  hash_val, T key) {
        
int  bin_pos  =  hash_val  %   this .num_bins;
        StTableEntry
< T, V >  entry  =   this .bins[bin_pos];
        
if  (entryNotEqual(entry, key, hash_val)) {
            entry 
=  entry.next;
            
while  (entryNotEqual(entry, key, hash_val)) {
                entry 
=  entry.next;
            }
        }
        
return  entry;
    }

    
//  判断元素是否相同
     private   boolean  entryNotEqual(StTableEntry < T, V >  entry, T key,  int  hash_val) {
        
return  entry  !=   null
                
&&  (entry.getHash()  !=  hash_val  ||  ( ! key.equals(entry.getKey())));
    }

}

  单元测试类就不列了,给一个与HashMap的简单性能对比,以整型为键,显然StTable快多了,对于字符串型,关键是HashFunction的定义,我直接调用String的hashCode方法,不知道有没有其他更好的方法让元素分布的更均匀些:
import  java.util.HashMap;
import  java.util.Map;

public   class  Benchmark {
    
public   static   void  main(String args[]) {
       
long  map_cost  =  testStringMap();
        
long  table_cost  =  testStringTable();
        
if  (map_cost  <=  table_cost)
            System.out.println(
" map is faster than table  " );
        
else
            System.out.println(
" table is faster than map  " );

        map_cost 
=  testIntegerMap();
        table_cost 
=  testIntegerTable();
        
if  (map_cost  <=  table_cost)
            System.out.println(
" map is faster than table  " );
        
else
            System.out.println(
" table is faster than map  " );
    }

    
public   static   long  testIntegerMap() {
        Map
< Integer, Integer >  map  =   new  HashMap < Integer, Integer > ();
        
long  start  =  System.nanoTime();
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            map.put(i, i);
        
long  result  =   0 ;
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            result 
+=  map.get(i);
        
long  end  =  System.nanoTime();
        System.out.println(
" result: "   +  result);
        System.out.println(
" map: "   +  (end  -  start));
        
return  (end  -  start);
    }

    
public   static   long  testIntegerTable() {
        HashFunction
< Integer >  intHash  =   new  HashFunction < Integer > () {
            
public   int  hash(Integer key) {
                
return  key;
            }
        };
        StTable
< Integer, Integer >  table  =   new  StTable < Integer, Integer > (intHash);
        
long  start  =  System.nanoTime();
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            table.put(i, i);
        
long  result  =   0 ;
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            result 
+=  table.get(i);
        
long  end  =  System.nanoTime();
        System.out.println(
" result: "   +  result);
        System.out.println(
" table: "   +  (end  -  start));
        
return  (end  -  start);
    }

    
public   static   long  testStringMap() {
        Map
< String, String >  map  =   new  HashMap < String, String > ();
        
long  start  =  System.nanoTime();
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            map.put(String.valueOf(i), String.valueOf(i));
        
long  result  =   0 ;
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            result 
+=  Integer.parseInt(map.get(String.valueOf(i)));
        
long  end  =  System.nanoTime();
        System.out.println(
" result: "   +  result);
        System.out.println(
" map: "   +  (end  -  start));
        
return  (end  -  start);
    }

    
public   static   long  testStringTable() {
        HashFunction
< String >  intHash  =   new  HashFunction < String > () {
            
int  i  =   0 ;
            
public   int  hash(String key) {
                
int  hashCode  =  key.hashCode();
                
return  hashCode  <   0   ?   - hashCode : hashCode;
            }
        };
        StTable
< String, String >  table  =   new  StTable < String, String > (intHash);
        
long  start  =  System.nanoTime();
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            table.put(String.valueOf(i), String.valueOf(i));
        
long  result  =   0 ;
        
for  ( int  i  =   0 ; i  <   10000 ; i ++ )
            result 
+=  Integer.parseInt(table.get(String.valueOf(i)));
        
long  end  =  System.nanoTime();
        System.out.println(
" result: "   +  result);
        System.out.println(
" table: "   +  (end  -  start));
        
return  (end  -  start);
    }

}

结果为:
result:49995000
map:55501468
result:49995000
table:60999652
map is faster than table

 
result:49995000
map:44634444
result:49995000
table:26209477
table is faster than map

将get换成remove方法,结果也与上面的类似。



你可能感兴趣的:(模仿st_table写的StTable类)