# Version 2 (版本2)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Filename:test_expr.py
import unittest
from expr import *
class ExprTestCase(unittest.TestCase):
    """Tests for the single-digit string expression evaluator in expr (version 2).

    unittest.TestCase already provides no-op setUp/tearDown, so they are not
    overridden here.
    """

    def test_num_index0(self):
        self.assertEqual(1, num("1", 0))
        self.assertEqual(2, num("2+3", 0))

    def test_num_index2(self):
        self.assertEqual(3, num("2+3", 2))

    def test_add(self):
        self.assertEqual(2, add("1+1"))
        self.assertEqual(3, add("1+2"))

    def test_add3(self):
        self.assertEqual(6, add("1+2+3"))

    def test_minus(self):
        self.assertEqual(1, minus("2-1"))
        self.assertEqual(0, minus("2-1-1"))

    def test_mul(self):
        self.assertEqual(6, mul("2*3"))
        self.assertEqual(24, mul("2*3*4"))

    def test_div(self):
        self.assertEqual(2, div("6/3"))
        self.assertEqual(1, div("9/3/3"))

    def test_mix(self):
        # mix() parses a *string*; passing the Python expression 1+2*3 would
        # hand it the already-computed int 7 and never exercise the parser.
        self.assertEqual(7, mix("1+2*3"))
        # '*' must also bind tighter than a following '-'.
        self.assertEqual(2, mix("2*3-4"))
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): run the tests.
    unittest.main()
def num(s, pos):
    """Return the character at index *pos* of *s* converted with int().

    Operands are single characters, so num("2+3", 2) == 3. Raises IndexError
    if *pos* is out of range and ValueError if that character is not a digit.
    """
    digit_char = s[pos]
    return int(digit_char)
def add(s):
    """Evaluate a left-to-right '+' chain such as "1+2+3" by delegating to chains()."""
    total = chains(s)
    return total
def minus(s):
    """Evaluate a left-to-right '-' chain such as "2-1-1" by delegating to chains()."""
    difference = chains(s)
    return difference
def mul(s):
    """Evaluate a left-to-right '*' chain such as "2*3*4" by delegating to chains()."""
    product = chains(s)
    return product
def div(s):
    """Evaluate a left-to-right '/' chain such as "9/3/3" by delegating to chains()."""
    quotient = chains(s)
    return quotient
def apply_mix(val, s, pos):
    """Fold the operator at s[pos], and the digit right after it, into *val*.

    If s[pos] is one of '+', '-', '*', '/', the digit at pos + 1 (read with
    num()) is combined with *val* using that operator and the result is
    returned. Any other character leaves *val* unchanged, which is why
    chains() can call this on every index of the string.
    """
    op = s[pos]
    if op not in ('+', '-', '*', '/'):
        return val
    operand = num(s, pos + 1)
    if op == '+':
        return val + operand
    if op == '-':
        return val - operand
    if op == '*':
        return val * operand
    return val / operand
def chains(s):
    """Evaluate *s* strictly left to right, with no operator precedence.

    The first character must be a digit. Every later index is handed to
    apply_mix(), which applies operators and skips digits.
    """
    result = num(s, 0)
    pos = 1
    while pos < len(s):
        result = apply_mix(result, s, pos)
        pos += 1
    return result
def mix(s):
    """Evaluate *s* so that '*' and '/' bind tighter than '+' and '-'.

    The string is split at every '+' / '-' into terms. Each term is a '*'/'/'
    chain evaluated left to right by mul(), and the term values are then
    added or subtracted left to right. Operands are single digits, as in the
    rest of this module. Examples: mix("1+2*3") == 7, mix("2*3-4") == 2.

    (The previous body, add(mul(s)), passed mul's numeric result to add(),
    which indexes its argument as a string, so it failed on every input.)

    Raises IndexError when a term is empty, e.g. for "", "1++2" or "-1".
    """
    total = 0
    sign = '+'  # operator preceding the term currently being collected
    term = ''
    # The trailing '+' is a sentinel that flushes the final term.
    for ch in s + '+':
        if ch in '+-':
            value = mul(term)
            total = total + value if sign == '+' else total - value
            sign = ch
            term = ''
        else:
            term += ch
    return total
# Version 1 (版本1)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Filename:test_expr.py
import unittest
from expr import *
class ExprTestCase(unittest.TestCase):
    """Tests for the single-digit add/minus string evaluator in expr (version 1)."""

    def setUp(self):
        # No fixtures: every test works on string literals.
        return

    def tearDown(self):
        # Nothing to release.
        return

    def test_num_index0(self):
        for expected, text in ((1, "1"), (2, "2+3")):
            self.assertEqual(expected, num(text, 0))

    def test_num_index2(self):
        self.assertEqual(3, num("2+3", 2))

    def test_add(self):
        for expected, text in ((2, "1+1"), (3, "1+2")):
            self.assertEqual(expected, add(text))

    def test_add3(self):
        self.assertEqual(6, add("1+2+3"))

    def test_minus(self):
        for expected, text in ((1, "2-1"), (0, "2-1-1")):
            self.assertEqual(expected, minus(text))

    def test_add_minus(self):
        self.assertEqual(4, add_minus("3+2-1"))
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): run the tests.
    unittest.main()
def num(s, pos):
    """Return the character at index *pos* of *s* converted with int().

    Operands are single characters, so num("2+3", 2) == 3. Raises IndexError
    if *pos* is out of range and ValueError if that character is not a digit.
    """
    digit_char = s[pos]
    return int(digit_char)
def add(s):
    """Evaluate a '+' chain such as "1+2+3" by delegating to add_minus()."""
    total = add_minus(s)
    return total
def minus(s):
    """Evaluate a '-' chain such as "2-1-1" by delegating to add_minus()."""
    difference = add_minus(s)
    return difference
def add_minus(s):
    """Evaluate a left-to-right chain of single-digit '+'/'-' terms, e.g. "3+2-1".

    The first character is read as the starting digit. After that, each '+'
    or '-' consumes the digit immediately following it, and every other
    character is skipped.
    """
    result = num(s, 0)
    for idx, ch in enumerate(s[1:], start=1):
        if ch == '+':
            result += num(s, idx + 1)
        elif ch == '-':
            result -= num(s, idx + 1)
    return result
# ---------------------------------------------------------------------
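# Expression-tree constructors: num, plus, minus, mul (div is not implemented yet).
# num('1')                  --> ('num', '1')
# "1+2" is written as plus(num('1'), num('2'))  --> ('+', ('num', '1'), ('num', '2'))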
def extract(notation):
    """Split an expression node into (op, first, second).

    A leaf node such as ('num', '1') has two fields and is padded with None.
    A binary node such as ('+', left, right) has three fields.

    Raises:
        ValueError: if *notation* has neither 2 nor 3 fields. Previously such
            input silently returned None, and the caller then failed with an
            unrelated unpacking TypeError.
    """
    if len(notation) == 2:
        return (notation[0], notation[1], None)
    if len(notation) == 3:
        return (notation[0], notation[1], notation[2])
    raise ValueError(
        'expression node must have 2 or 3 fields, got %r' % (notation,))
def expr(notation):
    """Recursively evaluate an expression tree built with num/plus/minus/mul.

    *notation* is either a leaf ('num', text) or a binary node
    (op, left, right) with op in '+', '-', '*', whose children are nodes.

    Raises:
        ValueError: if the node's operator is not supported. Previously such
            a node silently evaluated to None.
    """
    op, n1, n2 = extract(notation)
    if op == 'num':
        return do_num(n1)
    elif op == '+':
        return do_add(n1, n2)
    elif op == '-':
        return do_minus(n1, n2)
    elif op == '*':
        return do_mul(n1, n2)
    raise ValueError('unsupported operator in expression node: %r' % (op,))
'''
Example trace: expr(num('1'))
  -> extract() yields op == 'num'
  -> returns do_num('1'), i.e. int('1') == 1
'''
def num(n):
    """Build a leaf node ('num', n) holding the literal *n* for expr()."""
    leaf = ('num', n)
    return leaf
def plus(n1, n2):
    """Build an addition node ('+', n1, n2) from two child nodes."""
    node = ('+', n1, n2)
    return node
def minus(n1, n2):
    """Build a subtraction node ('-', n1, n2) from two child nodes."""
    node = ('-', n1, n2)
    return node
def mul(n1, n2):
    """Build a multiplication node ('*', n1, n2) from two child nodes."""
    node = ('*', n1, n2)
    return node
def do_num(n1):
    """Convert a leaf node's literal to an integer with int()."""
    value = int(n1)
    return value
def do_add(n1, n2):
    """Evaluate the left child, then the right child, and return their sum."""
    left = expr(n1)
    right = expr(n2)
    return left + right
def do_minus(n1, n2):
    """Evaluate the left child, then the right child, and return left - right."""
    left = expr(n1)
    right = expr(n2)
    return left - right
def do_mul(n1, n2):
    """Evaluate the left child, then the right child, and return their product."""
    left = expr(n1)
    right = expr(n2)
    return left * right
# Import-time smoke checks for the expression-tree evaluator.
# These are plain assert statements, so they are skipped under `python -O`.
# 1
n1 = num('1')
assert expr(n1) == 1
# 1+2
n1 = num('1')
n2 = num('2')
assert expr(plus(n1, n2)) == 3
# (1+2)+3
n1 = num('1')
n2 = num('2')
n3 = plus(n1, n2)
n4 = num('3')
assert expr(plus(n3, n4)) == 6
# 3-2
n1 = num('3')
n2 = num('2')
assert expr(minus(n1, n2)) == 1
# (1+2)-3
n1 = num('1')
n2 = num('2')
n3 = plus(n1, n2)
n4 = num('3')
assert expr(minus(n3, n4)) == 0
# 3*2
n1 = num('3')
n2 = num('2')
assert expr(mul(n1, n2)) == 6