首先对 a、b、c排序
暴力超时
# coding=utf8
"""
给定三个整数数组
A = [A1, A2, ... AN],
B = [B1, B2, ... BN],
C = [C1, C2, ... CN],
请你统计有多少个三元组(i, j, k) 满足:
1. 1 <= i, j, k <= N
2. Ai < Bj < Ck
【样例输入
3
1 1 1
2 2 2
3 3 3
"""
import sys
N = int(input())
a = sorted(list(map(int, (input().split()))))
b = sorted(list(map(int, (input().split()))))
c = sorted(list(map(int, (input().split()))))
if len(a) != len(b) or len(b) != len(c):
print(None)
# print(a)
# print(b)
# print(c)
SUM = 0
la, lb, lc = N - 1, N - 1, N - 1
if a[0] >= b[lb] or b[0] >= c[lc]:
# a数组中最小的数大于等于b中最大的数 o
# b数组中最小的数大于等于c中最大的数 r
print(0)
sys.exit() # 退出程序
if c[0] > b[lb] and b[0] > a[la]:
# b数组中最小的数大于a中最大的数 并
# c数组中最小的数大于b中最大的数 且
print(N*N*N)
sys.exit()
for i in a:
for j in b:
for k in c:
if i < j < k:
SUM += 1
print(SUM)
二分找到max(j) 使得a[i] < b[j]
N = int(input())
a = sorted(list(map(int, (input().split()))))
b = sorted(list(map(int, (input().split()))))
c = sorted(list(map(int, (input().split()))))
if len(a) != len(b) or len(b) != len(c):
print(None)
def search_lower_idx(num, nums):
if num >= nums[-1]:
return -1
if num < nums[0]:
return 0
idx = len(nums)//2
while num > nums[idx]:
idx = (len(nums) + idx)//2
while num < nums[idx]:
idx //= 2
while num == nums[idx]:
idx += 1
return idx
SUM = 0
for la in range(N):
# print('la', la)
idxb = search_lower_idx(a[la], b)
# print('idxb', idxb)
if idxb == -1:
continue
for lb in range(idxb, N):
# print(f'lb {lb}')
idxc = search_lower_idx(b[lb], c)
# print(f'idxc {idxc}')
if idxc != -1:
SUM += N - idxc
# print(f'SUM {SUM}')
# print(f'la {la} SUM {SUM}\n')
print(SUM)
"""
4
3 4 2 1
3 4 2 1
3 4 2 1
"""
二分的方法对于某些测试不通过 为什么呢 这是为什么呢?