UVALive 4048 Fund Management

题目大意:你有 c 美元,没有股票。给你 m 天时间和 n 支股票供你买卖,每天只能进行一次操作(买入一手、卖出一手或不动)。对于第 i 支股票,已知它在第 j 天的每股价格为 P_ij,一手为 s_i 股,且最多同时持有 k_i 手该股票;另外,所持股票的总手数不能超过 k。问 m 天后最多能得到多少钱(最后一天结束时不能持有任何股票),并打印操作路径。

思路:一看数据范围,很容易想到状压。我的思路是用 d[i][j] 表示第 i 天、手上每支股票的持有手数为 j 时的最优值,其中 j 相当于一个 9 进制数,每一位表示对应股票的手数。用 map 来存 d 数组,再记忆化搜索,状态转移时注意判断条件是否满足。

哎,搞了一个下午,一直 WA,连写博客的心情都没了;要是 TLE 倒也罢了。。 T^T

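先贴出这份 WA 的代码,挖个坑,以后再来填吧。。希望路过的大神能不吝赐教。。 T^T(补注:下面代码按 (day, s) 做记忆化是有问题的——能否买入取决于当时手上的现金 c,而 c 不在记忆化的键里,不同的 c 会错误地共用同一个结果;另外 main 中读价格时残留的调试输出 printf 也会直接导致 WA。更稳妥的做法是正向 DP:对每一天、每种持仓只记录能达到的最大现金,因为现金越多一定不会更差。)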

代码如下(WA ):

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
using namespace std;

#define MP make_pair

// 64-bit integer used for money in cents. `long long` is standard C++ (MSVC's `__int64`
// is a synonym for it), so unlike `__int64` this also compiles with GCC/Clang.
typedef long long lld;

// Sentinel for "unreachable / invalid" results.
const lld INF = 1000000000000000LL;

// Half a cent: added before truncating dollars*100 so decimal input rounds to the nearest cent.
const double EPS = 0.005;

// One tradable stock, as read from the input.
struct Stock
{
    char name[11];   // ticker, read with %s and printed in BUY/SELL lines
    int len;         // shares per lot
    int k;           // maximum number of lots of this stock held at once
    lld pri[111];    // pri[day] = price of one lot on that day, in cents
} stock[11];

// Per-day tables indexed by day and keyed by the base-9 holdings code
// (cleared for every test case in main()).
std::map<int, lld> d[111];
std::map<int, int> vis[111];

// m = number of days, n = number of stocks, k = cap on the total number of lots held.
int m, n;
int k;

int count(int s)
{
    int cnt = 0;
    for(int i = 0;i<n;i++)
    {
        cnt += s%9;
        s = s/9;
    }
    return cnt;
}

lld exp[11];

void init()
{
    exp[0] = 1;
    for(int i = 1;i<8;i++)
    {
        exp[i] = exp[i-1] * 9;
    }
}

// path[day][holdings] = decision on `day`: (0, i) buy stock i, (1, i) sell stock i, (-1, -1) hold.
std::map<int, std::pair<int, int> > path[111];

lld dp(int day,int s,lld c)
{
    if(vis[day][s]) return d[day][s];
    vis[day][s] = 1;
    lld &ans = d[day][s];
    if(day >= m )
    {
        if(s == 0) return ans = 0;
        else return ans = -INF;
    }
    ans = -INF;
    int cc = count(s);
    //printf("day = %d,s = %d,cc = %d\n",day,s,cc);
    if(cc<k)
    {
        int tmp = s;
        for(int i = 0;i<n;i++)
        {
            if(tmp%9<stock[i].k)
            {
                if(c >= stock[i].pri[day] && dp(day+1,s + exp[i], c - stock[i].pri[day]) - stock[i].pri[day] > ans)
                {
                    ans = dp(day+1,s + exp[i],c - stock[i].pri[day]) - stock[i].pri[day];
                    path[day][s] = MP(0,i);
                    path[day][s] = MP(0,i);
                }
            }
            tmp /= 9;
        }
    }
    int tmp = s;
    for(int i = 0;i<n;i++)
    {
        if(tmp%9>0)
        {
            if(dp(day+1,s - exp[i] , c + stock[i].pri[day]) + stock[i].pri[day] > ans)
            {
                ans = dp(day+1,s - exp[i], c + stock[i].pri[day]) + stock[i].pri[day];
                path[day][s] = MP(1,i);
            }
        }
        tmp /= 9;
    }
    if(dp(day+1,s,c) > ans)
    {
        ans = dp(day+1,s,c);
        path[day][s] = MP(-1,-1);
    }
    //printf("day = %d,s = %d,ans = %I64d\n",day,s,ans);
    return ans;
}

// Replay the recorded decisions from (day, s) through the last day, one action per line.
void print(int day,int s)
{
    while(day < m)
    {
        pair<int,int> step = path[day][s];
        switch(step.first)
        {
        case 0:
            printf("BUY %s\n", stock[step.second].name);
            s = s + exp[step.second];
            break;
        case 1:
            printf("SELL %s\n", stock[step.second].name);
            s = s - exp[step.second];
            break;
        default:
            puts("HOLD");
            break;
        }
        ++day;
    }
}

int main()
{
    init();
    double tmp;
    // Each case: starting cash, number of days m, number of stocks n, cap k on total lots.
    while(~scanf("%lf%d%d%d",&tmp,&m,&n,&k))
    {
        // Work in integer cents; +EPS (half a cent) rounds the decimal input to the nearest cent.
        lld c = (lld)((tmp+EPS)*100);
        for(int i = 0;i<n;i++)
        {
            // name, shares per lot, max lots of this stock
            scanf("%s%d%d",stock[i].name,&stock[i].len,&stock[i].k);
            for(int j = 0;j<m;j++)
            {
                scanf("%lf",&tmp);
                // Price of one lot on day j, in cents. (A leftover debug printf of this value
                // used to be here; it corrupted the judged output and caused WA.)
                stock[i].pri[j] = (lld)((tmp+EPS)*100)*stock[i].len;
            }
        }
        for(int i = 0;i<=m;i++)
        {
            d[i].clear();
            vis[i].clear();
            path[i].clear();
        }
        printf("%.2f\n",(double)((c + dp(0,0,c))/100.0));
        print(0,0);
    }
    return 0;
}

/*
144624.00 9 5 3
IBM 500 3
97.27 98.31 97.42 98.9 100.07 98.89 98.65 99.34 100.82
GOOG 100 1
467.59 483.26 487.19 483.58 485.5 489.46 499.72 505 504.28
JAVA 1000 2
5.54 5.69 5.6 5.65 5.73 6 6.14 6.06 6.06
MSFT 250 1
29.86 29.81 29.64 29.93 29.96 29.66 30.7 31.21 31.16
ORCL 300 3
17.51 17.68 17.64 17.86 17.82 17.77 17.39 17.5 17.3

144624.00 2 1 1
IBM 500 3
97.27 98.31

100 1 1 1
A 1 1
100

*/

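再来一份 Java 的 AC 代码,是随手搜来的,还没仔细看,先贴上。它的思路是正向 DP:max[i][s] 表示第 i 天开始时持仓状态为 s 时手上最多的现金(-1 表示不可达),逐天尝试 HOLD / BUY / SELL 转移,最后取第 m 天空仓状态的值,并用 by 数组回溯出操作序列: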

import java.io.*;
import java.util.*;

public class Main {
    static Scanner in;
    static PrintWriter out;

    /** Starting cash, in cents. */
    long c;
    /** Number of trading days. */
    int m;
    /** Number of stocks. */
    int n;
    /** Maximum total number of lots that may be held at once. */
    int k;

    /** A stock read from the input; prices are stored per lot, in cents. */
    class Stock {
        String name;
        /** Maximum number of lots of this stock that may be held at once. */
        int count;
        /** price[d] = cost of one lot on day d, in cents (m must be set before construction). */
        long[] price = new long[m];

        Stock(Scanner in) {
            name = in.next();
            long size = in.nextInt();
            count = in.nextInt();

            for (int i = 0; i < m; i++) {
                // dollars per share -> cents per share, times shares per lot
                price[i] = Math.round(in.nextDouble() * 100) * size;
            }
        }
    }

    Stock[] stocks;

    /** transitions.get(code) = every move allowed from the holdings with that dense code. */
    List<Transition[]> transitions = new ArrayList<Transition[]>();
    /** Maps the mixed-radix (base k+1) holdings key to a dense state code 0..lastCode-1. */
    Map<Integer, Integer> codes = new HashMap<Integer, Integer>();
    int lastCode = 0;

    /** Returns the dense code of the holdings vector, assigning the next free code on first sight. */
    int encode(int[] count) {
        int q = 0;
        for (int i = 0; i < n; i++) {
            q = q * (k + 1) + count[i];
        }
        Integer code = codes.get(q);
        if (code == null) {
            code = lastCode++;
            codes.put(q, code);
        }
        return code;
    }

    /** A single-day move between holdings states: HOLD (buy = sell = -1), BUY, or SELL one lot. */
    class Transition {
        final int from;
        final int to;
        final int buy;
        final int sell;

        Transition(int from, int to, int buy, int sell) {
            this.from = from;
            this.to = to;
            this.buy = buy;
            this.sell = sell;
        }

        /** Cash spent on {@code day} by this move, in cents (negative when selling). */
        long cost(int day) {
            return 
                ((buy >= 0) ? stocks[buy].price[day] : 0) + 
                ((sell >= 0) ? -stocks[sell].price[day] : 0);
        }

        @Override
        public String toString() {
            if (buy >= 0) return "BUY " + stocks[buy].name;
            if (sell >= 0) return "SELL " + stocks[sell].name;
            return "HOLD";
        }
    }

    /**
     * Registers the holdings {@code count} ({@code held} lots in total) and, recursively, every
     * holdings reachable from it by single BUY/SELL moves, then returns its code.
     * {@code count} is modified temporarily but restored before returning.
     */
    int generate(int[] count, int held) {
        int code = encode(count);
        // A code equal to transitions.size() was just assigned, i.e. this state is new.
        if (transitions.size() == code) {
            List<Transition> moves = new ArrayList<Transition>();
            // Reserve the slot first so recursive calls treat this state as already visited.
            transitions.add(null);

            moves.add(new Transition(code, code, -1, -1));

            for (int i = 0; i < n; i++) {
                if (count[i] < stocks[i].count && held < k) {
                    count[i]++;
                    moves.add(new Transition(code, generate(count, held + 1), i, -1));
                    count[i]--;
                }
                if (count[i] > 0) {
                    count[i]--;
                    moves.add(new Transition(code, generate(count, held - 1), -1, i));
                    count[i]++;
                }
            }
            transitions.set(code, moves.toArray(new Transition[moves.size()]));
        }
        return code;
    }

    /** max[d][s] = most cash (cents) at the start of day d while holding state s; -1 = unreachable. */
    long[][] max;

    /** Reads one test case, runs the forward DP, and prints the final cash and the daily actions. */
    void run() {
        c = Math.round(in.nextDouble() * 100);
        m = in.nextInt();
        n = in.nextInt();
        k = in.nextInt();

        stocks = new Stock[n];
        for (int i = 0; i < n; i++) {
            stocks[i] = new Stock(in);
        }

        int root = generate(new int[n], 0);

        max = new long[m + 1][lastCode];
        Arrays.fill(max[0], -1);
        max[0][root] = c;

        // by[d][s] = the move made on day d-1 that produced max[d][s].
        Transition[][] by = new Transition[m + 1][lastCode];

        for (int i = 0; i < m; i++) {
            long[] current = max[i];
            long[] next = max[i + 1];
            Arrays.fill(next, -1);
            for (int j = 0; j < lastCode; j++) {
                if (current[j] >= 0) {
                    for (Transition transition : transitions.get(j)) {
                        int to = transition.to;
                        long cost = current[j] - transition.cost(i);
                        // An unaffordable BUY yields negative cash, which never beats the -1 sentinel.
                        if (next[to] < cost) {
                            next[to] = cost;
                            by[i + 1][to] = transition;
                        }
                    }
                }
            }
        }
        // Locale.US so the decimal separator is always '.', whatever the JVM default locale is.
        System.out.format(Locale.US, "%.2f\n", max[m][root] / 100.0);

        // Walk the predecessor moves back from the empty position on the last day.
        Transition[] actions = new Transition[m];
        int j = root;
        for (int i = m; i > 0; i--) {
            actions[i - 1] = by[i][j];
            j = by[i][j].from;
        }
        for (Transition t : actions) {
            System.out.println(t);
        }
    }

    public static void main(String[] args) throws IOException {
        in = new Scanner(System.in);
        // Parse "144624.00"-style numbers identically under every default locale.
        in.useLocale(Locale.US);
        new Main().run();
    }
}


你可能感兴趣的:(UVALive 4048 Fund Management)